mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-05-01 18:58:16 +00:00
add comments
This commit is contained in:
parent
268f92fafc
commit
b405eb02e7
@ -26,82 +26,77 @@ class Scene:
|
||||
|
||||
def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
|
||||
"""
|
||||
初始化3D场景对象
|
||||
|
||||
初始化3D场景对象
|
||||
args: 存储着与 GaussianMoedl 相关参数 的args,即包含scene/__init__.py/ModelParams()中的参数
|
||||
gaussians: 3D高斯模型对象,用于场景点的3D表示
|
||||
|
||||
load_iteration: 指定加载模型的迭代次数,如果是-1,则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型;如果不是None且不是-1,则加载指定迭代次数的
|
||||
shuffle: 是否在训练前打乱相机列表
|
||||
load_iteration: 训练时为 None,render时为 指定的iteration:如果是-1,则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型;如果不是None且不是-1,则加载指定迭代次数的
|
||||
shuffle: 是否在训练前打乱相机列表(render.py中会设为False)
|
||||
resolution_scales: 分辨率比例列表,用于处理不同分辨率的相机
|
||||
"""
|
||||
self.model_path = args.model_path # 模型文件保存路径
|
||||
self.loaded_iter = None # 已加载的迭代次数
|
||||
self.gaussians = gaussians # 高斯模型对象
|
||||
|
||||
# 如果已有训练模型,则加载
|
||||
# 1. 如果指定了加载模型的迭代次数,则赋给Scene.loaded_iter
|
||||
if load_iteration:
|
||||
if load_iteration == -1:
|
||||
# 是-1,则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型,记录最大迭代次数
|
||||
if load_iteration == -1: # 是-1,则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型,记录最大迭代次数
|
||||
self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
|
||||
else:
|
||||
# 不是None且不是-1,则加载指定迭代次数的
|
||||
else: # 不是None且不是-1,则加载指定迭代次数的
|
||||
self.loaded_iter = load_iteration
|
||||
print("Loading trained model at iteration {}".format(self.loaded_iter))
|
||||
|
||||
self.train_cameras = {} # 用于训练的相机
|
||||
self.test_cameras = {} # 用于测试的相机
|
||||
|
||||
# 从COLMAP或Blender的输出结果中构建 场景信息(包括点云、训练用相机、测试用相机、场景归一化参数和点云文件路径)
|
||||
# 2. 从COLMAP或Blender的输出结果中构建 scene_info(包含 点云、train相机info、test相机info、场景归一化参数、点云文件路径)
|
||||
if os.path.exists(os.path.join(args.source_path, "sparse")):
|
||||
scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
|
||||
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
|
||||
print("Found transforms_train.json file, assuming Blender data set!")
|
||||
scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
|
||||
else:
|
||||
assert False, "Could not recognize scene type!"
|
||||
assert False, "Could not recognize scene type! {}".format(os.path.join(args.source_path, "sparse"))
|
||||
|
||||
if not self.loaded_iter:
|
||||
# 如果没有加载模型,则将点云文件point3D.ply文件复制到input.ply文件
|
||||
# 未加载已训练模型,则:
|
||||
# (1) 将点云文件point3D.ply文件复制到input.ply文件
|
||||
with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
|
||||
dest_file.write(src_file.read())
|
||||
|
||||
# (2) 将相机参数写入cameras.json文件
|
||||
json_cams = []
|
||||
camlist = []
|
||||
if scene_info.test_cameras:
|
||||
# 测试相机添加到 camlist 中
|
||||
if scene_info.test_cameras: # 添加test相机到 camlist 中
|
||||
camlist.extend(scene_info.test_cameras)
|
||||
if scene_info.train_cameras:
|
||||
# 训练相机添加到 camlist 中
|
||||
if scene_info.train_cameras: # 添加train相机到 camlist 中
|
||||
camlist.extend(scene_info.train_cameras)
|
||||
# 遍历 camlist 中的所有相机,使用 camera_to_JSON 函数将每个相机转换为 JSON 格式,并添加到 json_cams 列表中,并将 json_cams 写入 cameras.json 文件中
|
||||
|
||||
# 遍历 camlist 中的所有训练和测试相机,使用 camera_to_JSON 函数将每个相机转换为 JSON 格式,并添加到 json_cams 列表中,并将 json_cams 写入 cameras.json 文件中
|
||||
for id, cam in enumerate(camlist):
|
||||
json_cams.append(camera_to_JSON(id, cam))
|
||||
with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
|
||||
json.dump(json_cams, file)
|
||||
|
||||
if shuffle:
|
||||
# 随机打乱训练和测试相机列表
|
||||
# 3. 随机打乱train、test相机info
|
||||
random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
|
||||
random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
|
||||
|
||||
self.cameras_extent = scene_info.nerf_normalization["radius"]
|
||||
|
||||
# 根据resolution_scales加载不同分辨率的训练和测试相机(包含R、T、视场角)
|
||||
# 4. 调整图片分辨率,并根据train、test相机info(包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height)创建相机
|
||||
for resolution_scale in resolution_scales:
|
||||
print("Loading Training Cameras")
|
||||
self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
|
||||
print("Loading Test Cameras")
|
||||
self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
|
||||
|
||||
# 5. 加载高斯体
|
||||
if self.loaded_iter:
|
||||
# 如果加载已训练模型,则直接读取对应(已经迭代出来的)场景
|
||||
self.gaussians.load_ply(os.path.join(self.model_path,
|
||||
"point_cloud",
|
||||
"iteration_" + str(self.loaded_iter),
|
||||
"point_cloud.ply"))
|
||||
# 如果加载已训练模型,则直接读取其已训练的高斯体
|
||||
self.gaussians.load_ply(os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply"))
|
||||
else:
|
||||
# 不加载训练模型,则调用 GaussianModel.create_from_pcd 从稀疏点云 scene_info.point_cloud 中建立模型
|
||||
# 不加载,则从稀疏点云中初始建立高斯体
|
||||
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
|
||||
|
||||
def save(self, iteration):
|
||||
|
@ -96,8 +96,8 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
|
||||
if intr.model=="SIMPLE_PINHOLE" or intr.model=="SIMPLE_RADIAL":
|
||||
# 如果是简单针孔模型,只有一个焦距参数
|
||||
focal_length_x = intr.params[0]
|
||||
FovY = focal2fov(focal_length_x, height) # 计算垂直方向的视场角
|
||||
FovX = focal2fov(focal_length_x, width) # 计算水平方向的视场角
|
||||
FovY = focal2fov(focal_length_x, height) # 计算垂直方向的视场角: 2 * arctan(H / 2fy))
|
||||
FovX = focal2fov(focal_length_x, width) # 计算水平方向的视场角: 2 * arctan(W / 2fx)
|
||||
elif intr.model=="PINHOLE":
|
||||
# 如果是针孔模型,有两个焦距参数
|
||||
focal_length_x = intr.params[0]
|
||||
@ -115,7 +115,7 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
|
||||
continue
|
||||
image = Image.open(image_path) # PIL.Image读取的为RGB格式,OpenCV读取的为BGR格式
|
||||
|
||||
# 创建相机信息类CameraInfo对象 (包含旋转矩阵、平移向量、视场角、图像数据、图片路径、图片名、宽度、高度),并添加到列表cam_infos中
|
||||
# 创建相机信息类CameraInfo对象 (包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height),并添加到列表cam_infos中
|
||||
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
||||
image_path=image_path, image_name=image_name, width=width, height=height)
|
||||
cam_infos.append(cam_info)
|
||||
@ -155,7 +155,7 @@ def storePly(path, xyz, rgb):
|
||||
def readColmapSceneInfo(path, images, eval, llffhold=8):
|
||||
'''
|
||||
加载COLMAP的结果中的二进制相机外参文件imags.bin 和 内参文件cameras.bin
|
||||
path: GaussianModel中的源文件路径
|
||||
path: source_path
|
||||
images: 'images'
|
||||
eval: 是否为eval模式
|
||||
llffhold: 默认为8
|
||||
@ -181,20 +181,20 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
|
||||
# 根据图片名称排序,以保证顺序一致性
|
||||
cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : (x.image_path.split('/')[-2], int(x.image_name)))
|
||||
|
||||
# 根据是否为评估模式(eval),将相机分为训练集和测试集
|
||||
# 如果为评估模式,每llffhold张图片取一张作为测试集
|
||||
# 根据是否eval,将相机分为训练集和测试集
|
||||
if eval:
|
||||
# 若要评测,则每llffhold张图片取一张作为测试集
|
||||
train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
|
||||
test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
|
||||
else:
|
||||
# 如果不是评估模式,所有相机均为训练相机,测试相机列表为空
|
||||
# 不评测,则所有相机均为训练相机,测试相机列表为空
|
||||
train_cam_infos = cam_infos
|
||||
test_cam_infos = []
|
||||
|
||||
# 计算场景归一化参数,这是为了处理不同尺寸和位置的场景,使模型训练更稳定
|
||||
nerf_normalization = getNerfppNorm(train_cam_infos)
|
||||
|
||||
# 尝试读取COLMAP生成的稀疏点云数据,优先从PLY文件读取,如果不存在,则尝试从BIN或TXT文件转换并保存为PLY格式
|
||||
# 读取COLMAP生成的稀疏点云数据,优先从PLY文件读取,如果不存在,则尝试从BIN或TXT文件转换并保存为PLY格式
|
||||
ply_path = os.path.join(path, "sparse/0/points3D.ply")
|
||||
bin_path = os.path.join(path, "sparse/0/points3D.bin")
|
||||
txt_path = os.path.join(path, "sparse/0/points3D.txt")
|
||||
@ -206,10 +206,9 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
|
||||
xyz, rgb, _ = read_points3D_text(txt_path)
|
||||
|
||||
storePly(ply_path, xyz, rgb) # 转换成ply文件
|
||||
|
||||
# 读取PLY格式的稀疏点云
|
||||
try:
|
||||
pcd = fetchPly(ply_path) # points3D.ply读取COLMAP产生的稀疏点云
|
||||
|
||||
pcd = fetchPly(ply_path)
|
||||
except:
|
||||
pcd = None
|
||||
|
||||
|
@ -17,12 +17,17 @@ from utils.graphics_utils import fov2focal
|
||||
WARNED = False
|
||||
|
||||
def loadCam(args, id, cam_info, resolution_scale):
|
||||
"""
|
||||
调整当前相机对应图像的分辨率,并根据当前相机的info创建相机(包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height)
|
||||
"""
|
||||
orig_w, orig_h = cam_info.image.size
|
||||
|
||||
# 1. 计算下采样后的图像尺寸
|
||||
if args.resolution in [1, 2, 4, 8]:
|
||||
# 计算下采样后的图像尺寸 [1, 1/2, 1/4, 1/8]
|
||||
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
|
||||
else: # should be a type that converts to float
|
||||
else:
|
||||
if args.resolution == -1:
|
||||
# 如果用户没有指定分辨率,即默认为-1,则自动判断图片的宽度是>1.6K:如果大于,则自动进行下采样到1.6K时的采样倍率;如果小于,则采样倍率=1,即使用原图尺寸
|
||||
if orig_w > 1600:
|
||||
global WARNED
|
||||
if not WARNED:
|
||||
@ -33,19 +38,22 @@ def loadCam(args, id, cam_info, resolution_scale):
|
||||
else:
|
||||
global_down = 1
|
||||
else:
|
||||
# 如果用户指定了分辨率,则根据用户指定的分辨率计算采样倍率
|
||||
global_down = orig_w / args.resolution
|
||||
|
||||
scale = float(global_down) * float(resolution_scale)
|
||||
resolution = (int(orig_w / scale), int(orig_h / scale))
|
||||
scale = float(global_down) * float(resolution_scale) # 缩放倍率
|
||||
resolution = (int(orig_w / scale), int(orig_h / scale)) # 下采样后的图像尺寸
|
||||
|
||||
resized_image_rgb = PILtoTorch(cam_info.image, resolution) # 调整图片比例,归一化,并转换通道为torch上的 (C, H, W)
|
||||
# 2. 调整图片分辨率,归一化,并转换通道为torch上的 (C, H, W)
|
||||
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
|
||||
|
||||
gt_image = resized_image_rgb[:3, ...]
|
||||
loaded_mask = None
|
||||
|
||||
if resized_image_rgb.shape[1] == 4:
|
||||
# 如果图片有alpha通道,则提取出来
|
||||
loaded_mask = resized_image_rgb[3:4, ...]
|
||||
|
||||
# 3. 创建相机
|
||||
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
|
||||
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
|
||||
image=gt_image, gt_alpha_mask=loaded_mask,
|
||||
@ -53,14 +61,14 @@ def loadCam(args, id, cam_info, resolution_scale):
|
||||
|
||||
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
|
||||
'''
|
||||
cam_infos: 训练或测试相机对象列表
|
||||
resolution_scale: 不同分辨率列表
|
||||
args: 高斯模型参数
|
||||
cam_infos: train或test相机info列表
|
||||
resolution_scale: 分辨率倍率
|
||||
args: 更新后的ModelParams()中的参数
|
||||
'''
|
||||
camera_list = []
|
||||
|
||||
# 遍历每个camera_info(包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height)
|
||||
for id, c in enumerate(cam_infos):
|
||||
camera_list.append(loadCam(args, id, c, resolution_scale))
|
||||
camera_list.append( loadCam(args, id, c, resolution_scale) )
|
||||
|
||||
return camera_list
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user