mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-03-31 15:50:00 +00:00
119 lines
4.2 KiB
Python
119 lines
4.2 KiB
Python
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
|
||
from argparse import ArgumentParser, Namespace
|
||
import sys
|
||
import os
|
||
|
||
class GroupParams:
    """Plain attribute container produced by ``ParamGroup.extract``.

    Instances start with no attributes; ``extract`` sets one attribute per
    parsed option that belongs to the originating parameter group.
    """
|
||
|
||
class ParamGroup:
    """Base class that turns a subclass's instance attributes into CLI options.

    Subclasses assign their defaults as instance attributes and then call
    ``super().__init__``. Every attribute becomes an ``--<name>`` option in a
    dedicated argument group. An attribute whose name starts with ``_`` is
    registered without the underscore and additionally gets a one-letter
    shorthand ``-<first letter>`` (e.g. ``_source_path`` -> ``--source_path`` /
    ``-s``).
    """

    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
        """Register one argparse option per instance attribute.

        Args:
            parser: parser to which a new argument group is added.
            name: title of that argument group.
            fill_none: when truthy, every option's default is None instead of
                the attribute's value. The option's type is still taken from
                the attribute's value, so parsing is unchanged; only the
                "not given" value differs.
        """
        group = parser.add_argument_group(name)
        for key, value in vars(self).items():
            shorthand = key.startswith("_")
            if shorthand:
                key = key[1:]
            value_type = type(value)
            default = None if fill_none else value

            flags = ["--" + key]
            if shorthand:
                flags.append("-" + key[0:1])

            # A bool default becomes a presence flag (store_true); any other
            # default's type is used to convert the command-line string.
            if value_type is bool:
                group.add_argument(*flags, default=default, action="store_true")
            else:
                group.add_argument(*flags, default=default, type=value_type)

    def extract(self, args):
        """Copy the entries of ``args`` that belong to this group.

        An entry belongs to this group when its name matches one of this
        object's instance attributes, either directly or with the leading
        ``_`` used for shorthand options.

        Args:
            args: parsed namespace, e.g. from ``parser.parse_args()``.

        Returns:
            A new ``GroupParams`` carrying one attribute per matching entry.
        """
        own_attrs = vars(self)
        group = GroupParams()
        for arg_name, arg_value in vars(args).items():  # e.g. ("sh_degree", 3)
            if arg_name in own_attrs or ("_" + arg_name) in own_attrs:
                setattr(group, arg_name, arg_value)
        return group
|
||
|
||
class ModelParams(ParamGroup):
    """Scene/model loading options, registered under "Loading Parameters".

    Each attribute set in ``__init__`` becomes a ``--<name>`` option via
    ``ParamGroup``; underscore-prefixed names additionally get a one-letter
    shorthand (e.g. ``-s`` for ``--source_path``).
    """

    def __init__(self, parser, sentinel=False):
        # Defaults. vars(self) preserves assignment order, which therefore
        # fixes the order in which the options are registered.
        self.sh_degree: int = 3
        self._source_path: str = ""          # --source_path / -s
        self._model_path: str = ""           # --model_path / -m
        self._images: str = "images"         # --images / -i
        self._resolution: int = -1           # --resolution / -r
        self._white_background: bool = False  # --white_background / -w (store_true)
        self.data_device: str = "cuda"
        self.eval: bool = False              # --eval (store_true)
        # With sentinel=True every option defaults to None, so a later merge
        # can tell which values were actually passed on the command line.
        super().__init__(parser, "Loading Parameters", fill_none=sentinel)

    def extract(self, args):
        """Collect this group's values from ``args`` with an absolute source path.

        Args:
            args: parsed namespace (command-line values, possibly merged with
                values preset elsewhere); it must provide ``source_path``.

        Returns:
            The ``GroupParams`` built by ``ParamGroup.extract``, with
            ``source_path`` passed through ``os.path.abspath``.
        """
        params = super().extract(args)
        # Resolve against the current working directory now. Note that
        # abspath("") yields the cwd and abspath(None) raises TypeError.
        params.source_path = os.path.abspath(params.source_path)
        return params
|
||
|
||
class PipelineParams(ParamGroup):
    """Boolean switches registered under "Pipeline Parameters".

    Every default here is False, so ``ParamGroup`` registers each attribute
    as a ``store_true`` flag of the same name.
    """

    def __init__(self, parser):
        self.convert_SHs_python: bool = False    # --convert_SHs_python
        self.compute_cov3D_python: bool = False  # --compute_cov3D_python
        self.debug: bool = False                 # --debug
        super().__init__(parser, name="Pipeline Parameters")
|
||
|
||
class OptimizationParams(ParamGroup):
    """Training hyper-parameters registered under "Optimization Parameters".

    ``ParamGroup`` derives each option's converter from its default's type:
    the int defaults below accept only integers, the float defaults accept
    floats, and ``random_background`` (a bool) becomes a ``store_true`` flag.
    Assignment order is kept because it determines option order.
    """

    def __init__(self, parser):
        self.iterations: int = 30_000

        # position_lr_* options
        self.position_lr_init: float = 0.00016
        self.position_lr_final: float = 0.0000016
        self.position_lr_delay_mult: float = 0.01
        self.position_lr_max_steps: int = 30_000

        # remaining *_lr options
        self.feature_lr: float = 0.0025
        self.opacity_lr: float = 0.05
        self.scaling_lr: float = 0.005
        self.rotation_lr: float = 0.001

        self.percent_dense: float = 0.01
        self.lambda_dssim: float = 0.2

        # densification / opacity-reset options
        self.densification_interval: int = 100
        self.opacity_reset_interval: int = 3000
        self.densify_from_iter: int = 500
        self.densify_until_iter: int = 15_000
        self.densify_grad_threshold: float = 0.0002

        self.random_background: bool = False  # --random_background (store_true)

        super().__init__(parser, name="Optimization Parameters")
|
||
|
||
def get_combined_args(parser: ArgumentParser, cmdline_args=None):
    """Merge a model's saved ``cfg_args`` file with command-line arguments.

    The command line is parsed with ``parser``. If it yields a ``model_path``,
    the file ``<model_path>/cfg_args`` is read and evaluated; it is expected
    to evaluate to an argparse ``Namespace``. If no model path is given or
    the file does not exist, an empty ``Namespace()`` is used instead.
    Every command-line value that is not None then overrides the file's
    value. Options registered with a None default (e.g. via ``fill_none`` /
    ``sentinel``) therefore only override when actually given.

    Args:
        parser: parser describing the accepted options.
        cmdline_args: argument list to parse; defaults to ``sys.argv[1:]``.

    Returns:
        A new ``Namespace`` holding the merged values.
    """
    if cmdline_args is None:
        cmdline_args = sys.argv[1:]
    args_cmdline = parser.parse_args(cmdline_args)

    cfgfile_string = "Namespace()"
    model_path = getattr(args_cmdline, "model_path", None)
    if model_path is None:
        print("Config file not found: no model path given")
    else:
        cfgfilepath = os.path.join(model_path, "cfg_args")
        print("Looking for config file in", cfgfilepath)
        try:
            with open(cfgfilepath) as cfg_file:
                print("Config file found: {}".format(cfgfilepath))
                cfgfile_string = cfg_file.read()
        except FileNotFoundError:
            # Best effort: fall back to command-line values only.
            print("Config file not found at {}".format(cfgfilepath))

    # SECURITY: eval() executes whatever Python the cfg_args file contains.
    # Only load model directories from trusted sources.
    args_cfgfile = eval(cfgfile_string)

    merged_dict = vars(args_cfgfile).copy()
    merged_dict.update(
        {k: v for k, v in vars(args_cmdline).items() if v is not None}
    )
    return Namespace(**merged_dict)
|