add comments in train.py

This commit is contained in:
liuzhi 2024-05-12 19:07:00 +08:00
parent a0534cfb64
commit bfa79e356a
4 changed files with 44 additions and 18 deletions

View File

@ -39,7 +39,9 @@ class ParamGroup:
def extract(self, args):
group = GroupParams()
for arg in vars(args).items():
# 遍历各参数
for arg in vars(args).items(): # 例arg = ('sh_degree', '3')
# 若每个参数的名称与ModelParams等类的属性名称相匹配则将该参数设置到新建的 GroupParams 对象的对应属性上
if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
setattr(group, arg[0], arg[1])
return group
@ -57,8 +59,12 @@ class ModelParams(ParamGroup):
super().__init__(parser, "Loading Parameters", sentinel)
def extract(self, args):
g = super().extract(args)
g.source_path = os.path.abspath(g.source_path)
'''
从args对象中提取出与 ModelParams类中定义的参数相匹配的值,并将它们封装到一个新的 GroupParams 对象中
args: 存储着 命令行和main中预设的参数
'''
g = super().extract(args) # 返回的GroupParams对象
g.source_path = os.path.abspath(g.source_path) # 更新为绝对路径
return g
class PipelineParams(ParamGroup):

View File

@ -20,7 +20,7 @@ python train.py --source_path ../../Dataset/3DGS_Dataset/xiangjiadang --model_pa
--convert_cov3D_python添加此标志以使用 PyTorch 而不是论文提出的pipeline计算 3D协方差 的forward and backward
--debug如果遇到错误请启用调试模式。如果光栅化器失败dump则会创建一个文件您可以在问题中将其转发给我们以便我们查看。
--debug_from调试速度慢。可以指定一个迭代次数(从 0 开始),指定数字之前的迭代会是活动状态
--debug_from调试速度慢。可以指定从哪一迭代(>= 0开始
--iterations训练的总迭代次数默认为 30_000

View File

@ -24,6 +24,8 @@ from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
# 尝试导入 PyTorch 提供的 TensorBoard 记录器 SummaryWriter 类
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_FOUND = True
@ -31,7 +33,15 @@ except ImportError:
TENSORBOARD_FOUND = False
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
'''
dataset: 只存储与Moedl相关参数的args
opt: 优化相关参数
pipe: 管道相关参数
checkpoint: 已训练模型的路径
debug_from: 从哪一个迭代开始debug
'''
first_iter = 0
# 创建保存结果的文件夹并保存模型相关的参数到cfg_args文件尝试创建tensorboard_writer记录训练过程
tb_writer = prepare_output_and_logger(dataset)
gaussians = GaussianModel(dataset.sh_degree) # 创建初始化高斯模型用于表示场景中的每个点的3D高斯分布
@ -144,7 +154,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
print("\n[ITER {}] Saving Checkpoint".format(iteration))
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
def prepare_output_and_logger(args):
def prepare_output_and_logger(args):
# 没有预设模型输出路径,则随机生成一个文件名,存储输出结果
if not args.model_path:
if os.getenv('OAR_JOB_ID'):
unique_str=os.getenv('OAR_JOB_ID')
@ -152,16 +163,17 @@ def prepare_output_and_logger(args):
unique_str = str(uuid.uuid4())
args.model_path = os.path.join("./output/", unique_str[0:10])
# Set up output folder
# 创建输出文件夹
print("Output folder: {}".format(args.model_path))
os.makedirs(args.model_path, exist_ok = True)
# 保存模型的配置参数
with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
cfg_log_f.write(str(Namespace(**vars(args))))
# Create Tensorboard writer
# 创建Tensorboard writer
tb_writer = None
if TENSORBOARD_FOUND:
tb_writer = SummaryWriter(args.model_path)
tb_writer = SummaryWriter(args.model_path) # 创建一个 SummaryWriter 对象,向 TensorBoard 记录训练过程中的各种指标
else:
print("Tensorboard not available: not logging progress")
return tb_writer
@ -207,31 +219,37 @@ if __name__ == "__main__":
# Set up command line argument parser
parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser) # 模型 相关参数
op = OptimizationParams(parser) # 优化 相关参数
pp = PipelineParams(parser) # 渲染 相关参数
model_prams = ModelParams(parser) # 定义存储 模型 相关参数的arg对象
optim_prams = OptimizationParams(parser) # 定义存储 优化 相关参数的arg对象
pipeline_prams = PipelineParams(parser) # 定义存储 渲染 相关参数的arg对象
parser.add_argument('--ip', type=str, default="127.0.0.1")
parser.add_argument('--port', type=int, default=6009)
parser.add_argument('--debug_from', type=int, default=-1)
parser.add_argument('--detect_anomaly', action='store_true', default=False)
parser.add_argument('--debug_from', type=int, default=-1) # 指定从哪一迭代(>= 0开始debug
parser.add_argument('--detect_anomaly', action='store_true', default=False) # action='store_true' 如果命令行中包含了这个参数,它的值将被设置为 True
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--start_checkpoint", type=str, default = None)
# 从命令行参数中解析出所有的参数值,并与上面的设置的参数一起存储到 Namespace 对象中即args
args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations)
print("Optimizing " + args.model_path)
# Initialize system state (RNG)
# 初始化系统的随机状态,以确保实验结果可复现 (RNG)
safe_state(args.quiet)
# Start GUI server, configure and run training
# 启动GUI 服务器, 监听指定的 IP 地址和端口,观察训练进度和调试问题
network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
torch.autograd.set_detect_anomaly(args.detect_anomaly) # 设置pytorch是否检测梯度异常
# model_prams.extract(args)将args中的属性即命令行和预设的参数中 与 ModelParams类中定义的参数相匹配的值并将它们封装到一个新的 GroupParams 对象中
training(model_prams.extract(args), optim_prams.extract(args), pipeline_prams.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
# All done
print("\nTraining complete.")

View File

@ -161,9 +161,11 @@ def safe_state(silent):
def flush(self):
old_f.flush()
# 若args.quiet 为 True不写入任何文本到标准输出管道
sys.stdout = F(silent)
# 设置随机种子,使得结果可复现
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))
torch.cuda.set_device(torch.device("cuda:0")) # torch 默认的 CUDA 设备为 cuda:0