mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-06-26 18:18:11 +00:00
Merge 25b71e16d9
into 8a70a8cd6f
This commit is contained in:
commit
558f7b77c2
52
train.py
52
train.py
@ -11,8 +11,9 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import time
|
||||||
from random import randint
|
from random import randint
|
||||||
from utils.loss_utils import l1_loss, ssim
|
from utils.loss_utils import l1_loss, ssim, ssim_optimized, create_window
|
||||||
from gaussian_renderer import render, network_gui
|
from gaussian_renderer import render, network_gui
|
||||||
import sys
|
import sys
|
||||||
from scene import Scene, GaussianModel
|
from scene import Scene, GaussianModel
|
||||||
@ -29,27 +30,28 @@ except ImportError:
|
|||||||
TENSORBOARD_FOUND = False
|
TENSORBOARD_FOUND = False
|
||||||
|
|
||||||
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
|
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
|
||||||
|
start_time=time.time()
|
||||||
first_iter = 0
|
first_iter = 0
|
||||||
tb_writer = prepare_output_and_logger(dataset)
|
tb_writer = prepare_output_and_logger(dataset) # Tensorboard writer
|
||||||
gaussians = GaussianModel(dataset.sh_degree)
|
gaussians = GaussianModel(dataset.sh_degree)
|
||||||
scene = Scene(dataset, gaussians)
|
scene = Scene(dataset, gaussians)
|
||||||
gaussians.training_setup(opt)
|
gaussians.training_setup(opt)
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
(model_params, first_iter) = torch.load(checkpoint)
|
(model_params, first_iter) = torch.load(checkpoint)
|
||||||
gaussians.restore(model_params, opt)
|
gaussians.restore(model_params, opt)
|
||||||
|
|
||||||
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
||||||
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
iter_start = torch.cuda.Event(enable_timing = True)
|
iter_start = torch.cuda.Event(enable_timing = True)
|
||||||
iter_end = torch.cuda.Event(enable_timing = True)
|
iter_end = torch.cuda.Event(enable_timing = True)
|
||||||
|
|
||||||
viewpoint_stack = None
|
viewpoint_stack = None
|
||||||
ema_loss_for_log = 0.0
|
ema_loss_for_log = 0.0
|
||||||
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
|
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
|
||||||
first_iter += 1
|
first_iter += 1
|
||||||
for iteration in range(first_iter, opt.iterations + 1):
|
for iteration in range(first_iter, opt.iterations + 1):
|
||||||
if network_gui.conn == None:
|
if network_gui.conn == None:
|
||||||
network_gui.try_connect()
|
network_gui.try_connect()
|
||||||
while network_gui.conn != None:
|
while network_gui.conn != None:
|
||||||
try:
|
try:
|
||||||
@ -64,9 +66,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
network_gui.conn = None
|
network_gui.conn = None
|
||||||
|
|
||||||
iter_start.record()
|
iter_start.record()
|
||||||
|
|
||||||
gaussians.update_learning_rate(iteration)
|
gaussians.update_learning_rate(iteration)
|
||||||
|
|
||||||
# Every 1000 its we increase the levels of SH up to a maximum degree
|
# Every 1000 its we increase the levels of SH up to a maximum degree
|
||||||
if iteration % 1000 == 0:
|
if iteration % 1000 == 0:
|
||||||
@ -83,14 +85,22 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
|||||||
|
|
||||||
bg = torch.rand((3), device="cuda") if opt.random_background else background
|
bg = torch.rand((3), device="cuda") if opt.random_background else background
|
||||||
|
|
||||||
render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
|
render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
|
||||||
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
|
# gt_image = viewpoint_cam.original_image.cuda()
|
||||||
|
# Ll1 = l1_loss(image, gt_image)
|
||||||
|
# loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
|
||||||
|
# loss.backward()
|
||||||
|
|
||||||
|
|
||||||
|
# ----- modified: the SSIM term below uses ssim_optimized with a prebuilt window -----
|
||||||
gt_image = viewpoint_cam.original_image.cuda()
|
gt_image = viewpoint_cam.original_image.cuda()
|
||||||
Ll1 = l1_loss(image, gt_image)
|
Ll1 = l1_loss(image, gt_image)
|
||||||
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
|
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_optimized(image, gt_image, window=window))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
#-------------------------------------
|
||||||
|
|
||||||
iter_end.record()
|
iter_end.record()
|
||||||
|
|
||||||
@ -131,13 +141,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
|||||||
print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
||||||
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
|
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
total_time = end_time - start_time
|
||||||
|
print(f"\nTraining complete. Total training time: {total_time:.2f} seconds.")
|
||||||
|
|
||||||
|
|
||||||
def prepare_output_and_logger(args):
|
def prepare_output_and_logger(args):
|
||||||
if not args.model_path:
|
if not args.model_path:
|
||||||
if os.getenv('OAR_JOB_ID'):
|
if os.getenv('OAR_JOB_ID'):
|
||||||
unique_str=os.getenv('OAR_JOB_ID')
|
unique_str=os.getenv('OAR_JOB_ID')
|
||||||
else:
|
else:
|
||||||
unique_str = str(uuid.uuid4())
|
unique_str = str(uuid.uuid4())
|
||||||
args.model_path = os.path.join("./output/", unique_str[0:10])
|
args.model_path = os.path.join("/mnt/data1/3dgs_modify_output/", unique_str[0:10])
|
||||||
|
|
||||||
# Set up output folder
|
# Set up output folder
|
||||||
print("Output folder: {}".format(args.model_path))
|
print("Output folder: {}".format(args.model_path))
|
||||||
@ -191,6 +206,11 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
#----------------------create window------------------
|
||||||
|
window_size=11
|
||||||
|
channel=3
|
||||||
|
window=create_window(window_size, channel)
|
||||||
|
#--------------------------------
|
||||||
# Set up command line argument parser
|
# Set up command line argument parser
|
||||||
parser = ArgumentParser(description="Training script parameters")
|
parser = ArgumentParser(description="Training script parameters")
|
||||||
lp = ModelParams(parser)
|
lp = ModelParams(parser)
|
||||||
|
@ -25,13 +25,14 @@ def gaussian(window_size, sigma):
|
|||||||
return gauss / gauss.sum()
|
return gauss / gauss.sum()
|
||||||
|
|
||||||
def create_window(window_size, channel):
|
def create_window(window_size, channel):
|
||||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
"""Create a 2D Gaussian window."""
|
||||||
|
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||||
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
||||||
return window
|
return window
|
||||||
|
|
||||||
def ssim(img1, img2, window_size=11, size_average=True):
|
def ssim(img1, img2, window_size=11, size_average=True):
|
||||||
channel = img1.size(-3)
|
channel = img1.size(-3) #channel=3
|
||||||
window = create_window(window_size, channel)
|
window = create_window(window_size, channel)
|
||||||
|
|
||||||
if img1.is_cuda:
|
if img1.is_cuda:
|
||||||
@ -62,3 +63,12 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
|||||||
else:
|
else:
|
||||||
return ssim_map.mean(1).mean(1).mean(1)
|
return ssim_map.mean(1).mean(1).mean(1)
|
||||||
|
|
||||||
|
# ----- modified: SSIM variant that accepts a precomputed window -----
|
||||||
|
def ssim_optimized(img1, img2, window=None, window_size=11, size_average=True):
    """Compute SSIM between two images, optionally reusing a prebuilt window.

    Building the window on every call is avoidable work; callers may build it
    once with ``create_window`` and pass it in. Passing a window that already
    lives on ``img1``'s device with ``img1``'s dtype avoids any conversion here.

    Args:
        img1: First image tensor; its size along dim -3 is taken as the
            channel count.
        img2: Second image tensor, forwarded unchanged to ``_ssim``.
        window: Optional kernel as produced by ``create_window``. When
            ``None``, a new one is built from ``window_size`` and the channel
            count of ``img1``.
        window_size: Side length used to build the window when ``window`` is
            ``None``. When ``window`` is supplied this argument is ignored and
            the side length is read from the window's last dimension.
        size_average: Forwarded unchanged to ``_ssim``.

    Returns:
        The result of ``_ssim`` for these inputs.
    """
    channel = img1.size(-3)

    if window is None:
        window = create_window(window_size, channel)
    else:
        # Keep the size handed to _ssim in step with the kernel actually used,
        # rather than trusting the default of 11.
        # NOTE(review): presumably _ssim derives its padding from this value - confirm.
        window_size = window.size(-1)

    # A single conversion covers both device and dtype (CPU or CUDA input).
    window = window.to(device=img1.device, dtype=img1.dtype)

    return _ssim(img1, img2, window, window_size, channel, size_average)
|
||||||
|
Loading…
Reference in New Issue
Block a user