mirror of https://github.com/graphdeco-inria/gaussian-splatting (synced 2025-05-04 20:21:46 +00:00)

add comments in train.py

This commit is contained in:
parent a0534cfb64
commit bfa79e356a
@@ -39,7 +39,9 @@ class ParamGroup:
def extract(self, args):
group = GroupParams()
for arg in vars(args).items():
# Iterate over each parameter
for arg in vars(args).items(): # e.g. arg = ('sh_degree', 3)
# If a parameter's name matches an attribute defined in ModelParams (or a sibling class), set it on the new GroupParams object
if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
setattr(group, arg[0], arg[1])
return group
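As a rough, self-contained sketch of the matching logic commented above (the class DemoParams and its attributes are hypothetical, not the repository's definitions): extract copies every attribute of the parsed args whose name, with or without a leading underscore, matches an attribute of the parameter class onto a fresh GroupParams object.

# Hypothetical stand-alone sketch of the matching logic in ParamGroup.extract
from argparse import Namespace

class GroupParams:
    pass

class DemoParams:
    def __init__(self):
        self.sh_degree = 3        # attribute names here decide which args get copied
        self._source_path = ""    # a leading underscore also counts ("_" + name)

    def extract(self, args):
        group = GroupParams()
        for name, value in vars(args).items():
            if name in vars(self) or ("_" + name) in vars(self):
                setattr(group, name, value)
        return group

g = DemoParams().extract(Namespace(sh_degree=2, source_path="./data", unrelated=42))
print(g.sh_degree, g.source_path)  # 2 ./data  ('unrelated' is ignored)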
@@ -57,8 +59,12 @@ class ModelParams(ParamGroup):
super().__init__(parser, "Loading Parameters", sentinel)

def extract(self, args):
g = super().extract(args)
g.source_path = os.path.abspath(g.source_path)
'''
Extract from args the values that match the parameters defined in the ModelParams class and wrap them in a new GroupParams object
args: holds the command-line parameters and the presets defined in main
'''
g = super().extract(args) # the GroupParams object returned by the base class
g.source_path = os.path.abspath(g.source_path) # convert to an absolute path
return g

class PipelineParams(ParamGroup):
@@ -20,7 +20,7 @@ python train.py --source_path ../../Dataset/3DGS_Dataset/xiangjiadang --model_pa
--convert_cov3D_python: add this flag to compute the forward and backward of the 3D covariance with PyTorch instead of the pipeline proposed in the paper

--debug: enable debug mode if you run into errors. If the rasterizer fails, a dump file is created that you can forward to us in an issue so we can take a look.
--debug_from: debugging is slow. You can specify an iteration count (starting from 0); iterations before that number remain active
--debug_from: debugging is slow. You can specify the iteration (>= 0) from which debugging starts

--iterations: total number of training iterations, default 30_000
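For instance (dataset and output paths are placeholders), a short debugging run using the flags above might look like:

python train.py --source_path <path/to/dataset> --model_path <path/to/output> --debug --debug_from 500 --iterations 7000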
44  train.py
@@ -24,6 +24,8 @@ from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams

# Try to import the TensorBoard logger (SummaryWriter) provided by PyTorch
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_FOUND = True
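As a quick sketch of how a SummaryWriter imported here is typically used (the log directory and tag below are illustrative, not the project's), scalars are written with add_scalar and visualized by pointing TensorBoard at the output folder:

# Minimal SummaryWriter usage sketch; directory and tag names are illustrative only
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("./output/demo_run")   # same pattern as SummaryWriter(args.model_path) below
for step in range(3):
    writer.add_scalar("train_loss/total_loss", 1.0 / (step + 1), step)
writer.close()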
@@ -31,7 +33,15 @@ except ImportError:
TENSORBOARD_FOUND = False

def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
'''
dataset: args holding only the model-related parameters
opt: optimization-related parameters
pipe: pipeline-related parameters
checkpoint: path of an already trained model
debug_from: the iteration from which to start debugging
'''
first_iter = 0
# Create the folder that stores the results, save the model-related parameters to the cfg_args file, and try to create a tensorboard_writer to log the training process
tb_writer = prepare_output_and_logger(dataset)

gaussians = GaussianModel(dataset.sh_degree) # create and initialize the Gaussian model used to represent each point of the scene as a 3D Gaussian distribution
@@ -144,7 +154,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
print("\n[ITER {}] Saving Checkpoint".format(iteration))
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")

def prepare_output_and_logger(args):
# If no model output path was preset, generate a random folder name to store the results
if not args.model_path:
if os.getenv('OAR_JOB_ID'):
unique_str=os.getenv('OAR_JOB_ID')
@@ -152,16 +163,17 @@ def prepare_output_and_logger(args):
unique_str = str(uuid.uuid4())
args.model_path = os.path.join("./output/", unique_str[0:10])

# Set up output folder
# Create the output folder
print("Output folder: {}".format(args.model_path))
os.makedirs(args.model_path, exist_ok = True)
# Save the model's configuration parameters
with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
cfg_log_f.write(str(Namespace(**vars(args))))

# Create Tensorboard writer
# Create the TensorBoard writer
tb_writer = None
if TENSORBOARD_FOUND:
tb_writer = SummaryWriter(args.model_path)
tb_writer = SummaryWriter(args.model_path) # create a SummaryWriter object that logs the various training metrics to TensorBoard
else:
print("Tensorboard not available: not logging progress")
return tb_writer
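A small, hedged sketch of the cfg_args round trip performed above: str(Namespace(**vars(args))) yields a string like "Namespace(sh_degree=3, ...)", which can later be evaluated back into a Namespace as long as Namespace is in scope (the attributes below are illustrative only).

# Illustrative round trip for the cfg_args string written above
from argparse import Namespace

args = Namespace(sh_degree=3, source_path="/abs/path", white_background=False)
serialized = str(Namespace(**vars(args)))   # e.g. "Namespace(sh_degree=3, source_path='/abs/path', white_background=False)"

restored = eval(serialized)                 # requires Namespace to be in scope
print(restored.sh_degree, restored.source_path)  # 3 /abs/path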
@@ -207,31 +219,37 @@ if __name__ == "__main__":
# Set up command line argument parser
parser = ArgumentParser(description="Training script parameters")

lp = ModelParams(parser) # model-related parameters
op = OptimizationParams(parser) # optimization-related parameters
pp = PipelineParams(parser) # rendering-related parameters
model_prams = ModelParams(parser) # defines the arg object that stores the model-related parameters
optim_prams = OptimizationParams(parser) # defines the arg object that stores the optimization-related parameters
pipeline_prams = PipelineParams(parser) # defines the arg object that stores the rendering-related parameters

parser.add_argument('--ip', type=str, default="127.0.0.1")
parser.add_argument('--port', type=int, default=6009)
parser.add_argument('--debug_from', type=int, default=-1)
parser.add_argument('--detect_anomaly', action='store_true', default=False)
parser.add_argument('--debug_from', type=int, default=-1) # specify the iteration (>= 0) from which to start debugging
parser.add_argument('--detect_anomaly', action='store_true', default=False) # action='store_true': the value is set to True if the flag appears on the command line
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--start_checkpoint", type=str, default = None)

# Parse all parameter values from the command line and store them, together with the presets set above, in a Namespace object, i.e. args
args = parser.parse_args(sys.argv[1:])

args.save_iterations.append(args.iterations)

print("Optimizing " + args.model_path)

# Initialize system state (RNG)
# Initialize the system's random state to make results reproducible (RNG)
safe_state(args.quiet)

# Start GUI server, configure and run training
# Start the GUI server, listening on the specified IP address and port, to observe training progress and debug issues
network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)

torch.autograd.set_detect_anomaly(args.detect_anomaly) # tell PyTorch whether to detect anomalies in autograd

# model_prams.extract(args): take the attributes of args, i.e. the command-line and preset parameters, that match the parameters defined in ModelParams and wrap them in a new GroupParams object
training(model_prams.extract(args), optim_prams.extract(args), pipeline_prams.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)

# All done
print("\nTraining complete.")
@@ -161,9 +161,11 @@ def safe_state(silent):
def flush(self):
old_f.flush()

# If args.quiet is True, write nothing to the standard output stream
sys.stdout = F(silent)

# Set the random seeds so that results are reproducible
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))
torch.cuda.set_device(torch.device("cuda:0")) # torch's default CUDA device is cuda:0
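For context, a hedged sketch of the silent-stdout pattern this hunk belongs to; only def flush and the sys.stdout assignment are visible above, so the surrounding class F and old_f are reconstructed here as an assumption (the repository's version also stamps each written line with the current time):

# Assumed, simplified reconstruction of the pattern around this hunk
import sys, random
import numpy as np
import torch

def safe_state_sketch(silent):
    old_f = sys.stdout
    class F:
        def __init__(self, silent):
            self.silent = silent
        def write(self, x):
            if not self.silent:     # when --quiet is passed, nothing reaches stdout
                old_f.write(x)
        def flush(self):
            old_f.flush()
    sys.stdout = F(silent)

    # fixed seeds so that runs are reproducible
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    if torch.cuda.is_available():   # the original unconditionally selects cuda:0
        torch.cuda.set_device(torch.device("cuda:0"))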