diff --git a/arguments/__init__.py b/arguments/__init__.py index 1e13a55..4238f73 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -39,7 +39,9 @@ class ParamGroup: def extract(self, args): group = GroupParams() - for arg in vars(args).items(): + # 遍历各参数 + for arg in vars(args).items(): # 例arg = ('sh_degree', '3') + # 若每个参数的名称与ModelParams等类的属性名称相匹配,则将该参数设置到新建的 GroupParams 对象的对应属性上 if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): setattr(group, arg[0], arg[1]) return group @@ -57,8 +59,12 @@ class ModelParams(ParamGroup): super().__init__(parser, "Loading Parameters", sentinel) def extract(self, args): - g = super().extract(args) - g.source_path = os.path.abspath(g.source_path) + ''' + 从args对象中提取出与 ModelParams类中定义的参数相匹配的值,并将它们封装到一个新的 GroupParams 对象中 + args: 存储着 命令行和main中预设的参数 + ''' + g = super().extract(args) # 返回的GroupParams对象 + g.source_path = os.path.abspath(g.source_path) # 更新为绝对路径 return g class PipelineParams(ParamGroup): diff --git a/run_code.txt b/run_code.txt index 9d47031..d8fc40c 100644 --- a/run_code.txt +++ b/run_code.txt @@ -20,7 +20,7 @@ python train.py --source_path ../../Dataset/3DGS_Dataset/xiangjiadang --model_pa --convert_cov3D_python:添加此标志以使用 PyTorch 而不是论文提出的pipeline计算 3D协方差 的forward and backward --debug:如果遇到错误,请启用调试模式。如果光栅化器失败,dump则会创建一个文件,您可以在问题中将其转发给我们,以便我们查看。 ---debug_from:调试速度慢。可以指定一个迭代次数(从 0 开始),指定数字之前的迭代会是活动状态 +--debug_from:调试速度慢。可以指定从哪一迭代(>= 0)开始 --iterations:训练的总迭代次数,默认为 30_000 diff --git a/train.py b/train.py index b9d34e6..f619188 100644 --- a/train.py +++ b/train.py @@ -24,6 +24,8 @@ from tqdm import tqdm from utils.image_utils import psnr from argparse import ArgumentParser, Namespace from arguments import ModelParams, PipelineParams, OptimizationParams + +# 尝试导入 PyTorch 提供的 TensorBoard 记录器 SummaryWriter 类 try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_FOUND = True @@ -31,7 +33,15 @@ except ImportError: TENSORBOARD_FOUND = False def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): + ''' + dataset: 只存储与Moedl相关参数的args + opt: 优化相关参数 + pipe: 管道相关参数 + checkpoint: 已训练模型的路径 + debug_from: 从哪一个迭代开始debug + ''' first_iter = 0 + # 创建保存结果的文件夹,并保存模型相关的参数到cfg_args文件;尝试创建tensorboard_writer,记录训练过程 tb_writer = prepare_output_and_logger(dataset) gaussians = GaussianModel(dataset.sh_degree) # 创建初始化高斯模型,用于表示场景中的每个点的3D高斯分布 @@ -144,7 +154,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi print("\n[ITER {}] Saving Checkpoint".format(iteration)) torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") -def prepare_output_and_logger(args): +def prepare_output_and_logger(args): + # 没有预设模型输出路径,则随机生成一个文件名,存储输出结果 if not args.model_path: if os.getenv('OAR_JOB_ID'): unique_str=os.getenv('OAR_JOB_ID') @@ -152,16 +163,17 @@ def prepare_output_and_logger(args): unique_str = str(uuid.uuid4()) args.model_path = os.path.join("./output/", unique_str[0:10]) - # Set up output folder + # 创建输出文件夹 print("Output folder: {}".format(args.model_path)) os.makedirs(args.model_path, exist_ok = True) + # 保存模型的配置参数 with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: cfg_log_f.write(str(Namespace(**vars(args)))) - # Create Tensorboard writer + # 创建Tensorboard writer tb_writer = None if TENSORBOARD_FOUND: - tb_writer = SummaryWriter(args.model_path) + tb_writer = SummaryWriter(args.model_path) # 创建一个 SummaryWriter 对象,向 TensorBoard 记录训练过程中的各种指标 else: print("Tensorboard not available: not logging progress") return tb_writer @@ -207,31 +219,37 @@ if __name__ == "__main__": # Set up command line argument parser parser = ArgumentParser(description="Training script parameters") - lp = ModelParams(parser) # 模型 相关参数 - op = OptimizationParams(parser) # 优化 相关参数 - pp = PipelineParams(parser) # 渲染 相关参数 + model_prams = ModelParams(parser) # 定义存储 模型 相关参数的arg对象 + optim_prams = OptimizationParams(parser) # 定义存储 优化 相关参数的arg对象 + pipeline_prams = PipelineParams(parser) # 定义存储 渲染 相关参数的arg对象 parser.add_argument('--ip', type=str, default="127.0.0.1") parser.add_argument('--port', type=int, default=6009) - parser.add_argument('--debug_from', type=int, default=-1) - parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument('--debug_from', type=int, default=-1) # 指定从哪一迭代(>= 0)开始debug + parser.add_argument('--detect_anomaly', action='store_true', default=False) # action='store_true' 如果命令行中包含了这个参数,它的值将被设置为 True parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) parser.add_argument("--quiet", action="store_true") parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) parser.add_argument("--start_checkpoint", type=str, default = None) + + # 从命令行参数中解析出所有的参数值,并与上面的设置的参数一起存储到 Namespace 对象中,即args args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) print("Optimizing " + args.model_path) - # Initialize system state (RNG) + # 初始化系统的随机状态,以确保实验结果可复现 (RNG) safe_state(args.quiet) - # Start GUI server, configure and run training + # 启动GUI 服务器, 监听指定的 IP 地址和端口,观察训练进度和调试问题 network_gui.init(args.ip, args.port) - torch.autograd.set_detect_anomaly(args.detect_anomaly) - training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) + + torch.autograd.set_detect_anomaly(args.detect_anomaly) # 设置pytorch是否检测梯度异常 + + # model_prams.extract(args):将args中的属性,即命令行和预设的参数中 与 ModelParams类中定义的参数相匹配的值,并将它们封装到一个新的 GroupParams 对象中 + training(model_prams.extract(args), optim_prams.extract(args), pipeline_prams.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) # All done print("\nTraining complete.") diff --git a/utils/general_utils.py b/utils/general_utils.py index 5bdfa4c..37231fe 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -161,9 +161,11 @@ def safe_state(silent): def flush(self): old_f.flush() + # 若args.quiet 为 True,不写入任何文本到标准输出管道 sys.stdout = F(silent) + # 设置随机种子,使得结果可复现 random.seed(0) np.random.seed(0) torch.manual_seed(0) - torch.cuda.set_device(torch.device("cuda:0")) + torch.cuda.set_device(torch.device("cuda:0")) # torch 默认的 CUDA 设备为 cuda:0 \ No newline at end of file