gaussian-splatting/train.py
2024-05-13 21:42:34 +08:00

256 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
import numpy as np
import torch
from random import randint
from PIL import Image
from utils.loss_utils import l1_loss, ssim
from gaussian_renderer import render, network_gui
import sys
from scene import Scene, GaussianModel
from utils.general_utils import safe_state
import uuid
from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
# Try to import SummaryWriter, the TensorBoard logger that ships with PyTorch.
# TENSORBOARD_FOUND records whether the import succeeded, so the rest of the
# script can fall back to running without TensorBoard logging.
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
    """Run the optimization loop for one scene.

    Each iteration serves the network GUI (if connected), renders one randomly
    picked training camera, back-propagates the L1 + D-SSIM loss, then (under
    ``torch.no_grad()``) reports, saves, densifies/prunes, steps the optimizer
    and writes checkpoints as scheduled.

    Args:
        dataset: Model-related arguments. This function reads ``sh_degree``,
            ``white_background`` and ``source_path`` and passes the object on
            to ``prepare_output_and_logger`` and ``Scene``.
        opt: Optimization arguments (``iterations``, ``lambda_dssim``,
            ``random_background`` and the densification / opacity-reset
            schedule are read here).
        pipe: Pipeline arguments forwarded to ``render``. Its
            ``convert_SHs_python`` and ``compute_cov3D_python`` attributes are
            overwritten by values received from the GUI, and ``debug`` is set
            to ``True`` once ``iteration - 1 == debug_from``.
        testing_iterations: Iterations at which ``training_report`` evaluates.
        saving_iterations: Iterations at which ``scene.save`` is called.
        checkpoint_iterations: Iterations at which a checkpoint is written.
        checkpoint: Path of a checkpoint to resume from, or a falsy value to
            start from scratch.
        debug_from: Value compared against ``iteration - 1`` to switch on
            ``pipe.debug``.
    """
    first_iter = 0
    # Create the output folder, dump the model arguments to ``cfg_args`` and
    # try to create a TensorBoard writer (None when unavailable).
    tb_writer = prepare_output_and_logger(dataset)
    gaussians = GaussianModel(dataset.sh_degree)
    scene = Scene(dataset, gaussians)
    # Presumably configures the optimizer / learning-rate schedule of the
    # model; see GaussianModel.training_setup.
    gaussians.training_setup(opt)
    # Resume model parameters and the iteration counter from a checkpoint.
    if checkpoint:
        # NOTE(security): torch.load unpickles the file; only resume from
        # checkpoints that come from a trusted source.
        (model_params, first_iter) = torch.load(checkpoint)
        gaussians.restore(model_params, opt)

    # Constant background colour: white or black depending on the dataset.
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    # CUDA events used to time the forward + backward part of an iteration.
    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    first_iter += 1
    for iteration in range(first_iter, opt.iterations + 1):
        # Serve the network GUI: render whatever view it asks for and keep
        # looping here until it allows training to proceed.
        if network_gui.conn is None:
            network_gui.try_connect()
        while network_gui.conn is not None:
            try:
                net_image_bytes = None
                custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
                if custom_cam is not None:
                    net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
                    net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
                network_gui.send(net_image_bytes, dataset.source_path)
                if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
                    break
            except Exception:
                # Deliberate best effort: any GUI failure just drops the
                # connection so that training itself is never interrupted.
                network_gui.conn = None

        iter_start.record()

        gaussians.update_learning_rate(iteration)

        # Every 1000 its we increase the levels of SH up to a maximum degree
        if iteration % 1000 == 0:
            gaussians.oneupSHdegree()

        # Pick a random camera; refill the stack once every camera was used.
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))

        # Switch the pipeline to debug mode once the requested iteration is reached.
        if (iteration - 1) == debug_from:
            pipe.debug = True

        # Render, optionally against a random background colour.
        bg = torch.rand((3), device="cuda") if opt.random_background else background
        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]

        # Loss: weighted sum of L1 and (1 - SSIM).
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        loss.backward()

        iter_end.record()

        with torch.no_grad():
            # Progress bar: exponential moving average of the loss, refreshed
            # every 10 iterations.
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # Log and save
            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
            if iteration in saving_iterations:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)

            # Densification
            if iteration < opt.densify_until_iter:
                # Keep track of max radii in image-space for pruning
                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                    # The size threshold is only passed once the first
                    # opacity-reset interval has elapsed.
                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
                    gaussians.reset_opacity()

            # Optimizer step (skipped on the very last iteration).
            if iteration < opt.iterations:
                gaussians.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)

            # Periodic checkpoint: captured model state plus the iteration index.
            if iteration in checkpoint_iterations:
                print("\n[ITER {}] Saving Checkpoint".format(iteration))
                torch.save((gaussians.capture(), iteration), os.path.join(scene.model_path, "chkpnt" + str(iteration) + ".pth"))
def prepare_output_and_logger(args):
    """Create the output folder, store the run configuration and open a logger.

    When ``args.model_path`` is empty it is set (in place) to
    ``./output/<id>``, where ``<id>`` is the first 10 characters of the
    ``OAR_JOB_ID`` environment variable or, failing that, of a fresh UUID4.
    The folder is created if needed and ``str(Namespace(**vars(args)))`` is
    written to a ``cfg_args`` file inside it.

    Returns:
        A ``SummaryWriter`` pointed at the output folder when TensorBoard could
        be imported, otherwise ``None``.
    """
    if not args.model_path:
        # No output path given: derive a unique folder name.
        job_id = os.getenv('OAR_JOB_ID')
        unique_str = job_id if job_id else str(uuid.uuid4())
        args.model_path = os.path.join("./output/", unique_str[0:10])

    # Set up the output folder.
    output_dir = args.model_path
    print("Output folder: {}".format(output_dir))
    os.makedirs(output_dir, exist_ok=True)

    # Persist the configuration next to the results.
    cfg_path = os.path.join(output_dir, "cfg_args")
    with open(cfg_path, 'w') as cfg_log_f:
        cfg_log_f.write(str(Namespace(**vars(args))))

    # Create the TensorBoard writer, if TensorBoard is available.
    if not TENSORBOARD_FOUND:
        print("Tensorboard not available: not logging progress")
        return None
    return SummaryWriter(output_dir)
def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
    """Log per-iteration training scalars and, on test iterations, evaluate.

    Args:
        tb_writer: TensorBoard writer, or a falsy value to disable logging.
        iteration: Current training iteration (used as the global step).
        Ll1: L1 loss of the current iteration (``.item()`` is logged).
        loss: Total loss of the current iteration (``.item()`` is logged).
        l1_loss: Callable ``l1_loss(image, gt_image)`` used during evaluation.
        elapsed: Iteration time, logged as ``iter_time``.
        testing_iterations: Iterations at which the evaluation runs; ground
            truth images are logged only when ``iteration`` equals its first
            element.
        scene: Scene providing the cameras and the Gaussian model.
        renderFunc: Callable invoked as
            ``renderFunc(viewpoint, scene.gaussians, *renderArgs)``; its
            ``"render"`` entry is used.
        renderArgs: Extra positional arguments for ``renderFunc``.
    """
    if tb_writer:
        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
        tb_writer.add_scalar('iter_time', elapsed, iteration)

    # Everything below only runs on the scheduled test iterations.
    if iteration not in testing_iterations:
        return

    # Report test and samples of training set
    torch.cuda.empty_cache()
    test_cameras = scene.getTestCameras()
    # Fetch the training cameras once instead of once per sampled index, and
    # do not divide by zero when there are none.
    train_cameras = scene.getTrainCameras()
    num_train = len(train_cameras)
    train_samples = [train_cameras[idx % num_train] for idx in range(5, 30, 5)] if num_train else []
    validation_configs = ({'name': 'test', 'cameras': test_cameras},
                          {'name': 'train', 'cameras': train_samples})

    for config in validation_configs:
        cameras = config['cameras']
        if not cameras:
            continue
        l1_test = 0.0
        psnr_test = 0.0
        for idx, viewpoint in enumerate(cameras):
            image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
            gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
            # Only the first five views of each set are logged as images.
            if tb_writer and (idx < 5):
                tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
                # Ground truth does not change, so it is logged only once.
                if iteration == testing_iterations[0]:
                    tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
            l1_test += l1_loss(image, gt_image).mean().double()
            psnr_test += psnr(image, gt_image).mean().double()
        # Average the accumulated metrics over the evaluated views.
        psnr_test /= len(cameras)
        l1_test /= len(cameras)
        print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
        if tb_writer:
            tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
            tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)

    if tb_writer:
        tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
        tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Set up command line argument parser. Each *Params object registers its
    # own group of options (model / optimization / pipeline) on the parser.
    parser = ArgumentParser(description="Training script parameters")
    model_group = ModelParams(parser)
    optim_group = OptimizationParams(parser)
    pipeline_group = PipelineParams(parser)

    # Script-level options: GUI endpoint, debugging and scheduling.
    parser.add_argument('--ip', type=str, default="127.0.0.1")
    parser.add_argument('--port', type=int, default=6009)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument('--detect_anomaly', action='store_true', default=False)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[30_000])
    parser.add_argument("--start_checkpoint", type=str, default=None)

    cli_args = parser.parse_args(sys.argv[1:])
    # Always save at the final iteration as well.
    cli_args.save_iterations.append(cli_args.iterations)

    print("Optimizing " + cli_args.model_path)

    # Initialize system state (RNG)
    safe_state(cli_args.quiet)

    # Start the GUI server on the requested address, then configure autograd
    # anomaly detection as requested.
    network_gui.init(cli_args.ip, cli_args.port)
    torch.autograd.set_detect_anomaly(cli_args.detect_anomaly)

    # extract() presumably returns the subset of the parsed arguments that
    # belongs to the respective parameter group - see the arguments module.
    dataset_args = model_group.extract(cli_args)
    optim_args = optim_group.extract(cli_args)
    pipeline_args = pipeline_group.extract(cli_args)
    training(
        dataset=dataset_args,
        opt=optim_args,
        pipe=pipeline_args,
        testing_iterations=cli_args.test_iterations,
        saving_iterations=cli_args.save_iterations,
        checkpoint_iterations=cli_args.checkpoint_iterations,
        checkpoint=cli_args.start_checkpoint,
        debug_from=cli_args.debug_from,
    )

    # All done
    print("\nTraining complete.")