Update train.py

Updated the main entry point to call create_window() only once, since window_size is fixed at 11 and channel is fixed at 3 according to the source code.
This commit is contained in:
Lixing Xiao 2024-07-12 16:45:25 +08:00 committed by GitHub
parent 3a3220c1fe
commit c7ea596c57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,8 +11,9 @@
import os import os
import torch import torch
import time
from random import randint from random import randint
from utils.loss_utils import l1_loss, ssim from utils.loss_utils import l1_loss, ssim, ssim_optimized, create_window
from gaussian_renderer import render, network_gui from gaussian_renderer import render, network_gui
import sys import sys
from scene import Scene, GaussianModel from scene import Scene, GaussianModel
@ -29,20 +30,21 @@ except ImportError:
TENSORBOARD_FOUND = False TENSORBOARD_FOUND = False
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
start_time=time.time()
first_iter = 0 first_iter = 0
tb_writer = prepare_output_and_logger(dataset) tb_writer = prepare_output_and_logger(dataset) # Tensorboard writer
gaussians = GaussianModel(dataset.sh_degree) gaussians = GaussianModel(dataset.sh_degree) #高斯模型
scene = Scene(dataset, gaussians) scene = Scene(dataset, gaussians) #场景
gaussians.training_setup(opt) gaussians.training_setup(opt) #训练设置
if checkpoint: if checkpoint:
(model_params, first_iter) = torch.load(checkpoint) (model_params, first_iter) = torch.load(checkpoint)
gaussians.restore(model_params, opt) gaussians.restore(model_params, opt)
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] #背景颜色
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") #背景颜色
iter_start = torch.cuda.Event(enable_timing = True) iter_start = torch.cuda.Event(enable_timing = True) #开始时间
iter_end = torch.cuda.Event(enable_timing = True) iter_end = torch.cuda.Event(enable_timing = True) #结束时间
viewpoint_stack = None viewpoint_stack = None
ema_loss_for_log = 0.0 ema_loss_for_log = 0.0
@ -64,9 +66,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
except Exception as e: except Exception as e:
network_gui.conn = None network_gui.conn = None
iter_start.record() iter_start.record() #记录开始时间
gaussians.update_learning_rate(iteration) gaussians.update_learning_rate(iteration) #更新学习率
# Every 1000 its we increase the levels of SH up to a maximum degree # Every 1000 its we increase the levels of SH up to a maximum degree
if iteration % 1000 == 0: if iteration % 1000 == 0:
@ -83,14 +85,22 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
bg = torch.rand((3), device="cuda") if opt.random_background else background bg = torch.rand((3), device="cuda") if opt.random_background else background
render_pkg = render(viewpoint_cam, gaussians, pipe, bg) render_pkg = render(viewpoint_cam, gaussians, pipe, bg) #渲染
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
# Loss # Loss
# gt_image = viewpoint_cam.original_image.cuda()
# Ll1 = l1_loss(image, gt_image)
# loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
# loss.backward()
# ----------------modify-------------
gt_image = viewpoint_cam.original_image.cuda() gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image) Ll1 = l1_loss(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_optimized(image, gt_image, window=window))
loss.backward() loss.backward()
#-------------------------------------
iter_end.record() iter_end.record()
@ -131,13 +141,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
print("\n[ITER {}] Saving Checkpoint".format(iteration)) print("\n[ITER {}] Saving Checkpoint".format(iteration))
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
end_time = time.time() # 记录训练结束时间
total_time = end_time - start_time # 计算总时间
print(f"\nTraining complete. Total training time: {total_time:.2f} seconds.") # 打印总时间
def prepare_output_and_logger(args): def prepare_output_and_logger(args):
if not args.model_path: if not args.model_path:
if os.getenv('OAR_JOB_ID'): if os.getenv('OAR_JOB_ID'):
unique_str=os.getenv('OAR_JOB_ID') unique_str=os.getenv('OAR_JOB_ID')
else: else:
unique_str = str(uuid.uuid4()) unique_str = str(uuid.uuid4())
args.model_path = os.path.join("./output/", unique_str[0:10]) args.model_path = os.path.join("/mnt/data1/3dgs_modify_output/", unique_str[0:10])
# Set up output folder # Set up output folder
print("Output folder: {}".format(args.model_path)) print("Output folder: {}".format(args.model_path))
@ -191,6 +206,11 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
torch.cuda.empty_cache() torch.cuda.empty_cache()
if __name__ == "__main__": if __name__ == "__main__":
#----------------------create window------------------
window_size=11
channel=3
window=create_window(window_size, channel)
#--------------------------------
# Set up command line argument parser # Set up command line argument parser
parser = ArgumentParser(description="Training script parameters") parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser) lp = ModelParams(parser)