add comments

This commit is contained in:
liuzhi 2024-07-03 20:43:45 +08:00
parent 268f92fafc
commit b405eb02e7
3 changed files with 50 additions and 48 deletions

View File

@ -27,81 +27,76 @@ class Scene:
def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
""" """
初始化3D场景对象 初始化3D场景对象
args: 存储着与 GaussianModel 相关参数 的args即包含scene/__init__.py/ModelParams()中的参数 args: 存储着与 GaussianModel 相关参数 的args即包含scene/__init__.py/ModelParams()中的参数
gaussians: 3D高斯模型对象用于场景点的3D表示 gaussians: 3D高斯模型对象用于场景点的3D表示
load_iteration: 训练时为 None，render时为指定的iteration；如果是-1，则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型；如果不是None且不是-1，则加载指定迭代次数的模型
load_iteration: 指定加载模型的迭代次数如果是-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型如果不是None且不是-1则加载指定迭代次数的 shuffle: 是否在训练前打乱相机列表render.py中会设为False
shuffle: 是否在训练前打乱相机列表
resolution_scales: 分辨率比例列表用于处理不同分辨率的相机 resolution_scales: 分辨率比例列表用于处理不同分辨率的相机
""" """
self.model_path = args.model_path # 模型文件保存路径 self.model_path = args.model_path # 模型文件保存路径
self.loaded_iter = None # 已加载的迭代次数 self.loaded_iter = None # 已加载的迭代次数
self.gaussians = gaussians # 高斯模型对象 self.gaussians = gaussians # 高斯模型对象
# 如果已有训练模型,则加载 # 1. 如果指定了加载模型的迭代次数则赋给Scene.loaded_iter
if load_iteration: if load_iteration:
if load_iteration == -1: if load_iteration == -1: # 是-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型,记录最大迭代次数
# 是-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型,记录最大迭代次数
self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
else: else: # 不是None且不是-1则加载指定迭代次数的
# 不是None且不是-1则加载指定迭代次数的
self.loaded_iter = load_iteration self.loaded_iter = load_iteration
print("Loading trained model at iteration {}".format(self.loaded_iter)) print("Loading trained model at iteration {}".format(self.loaded_iter))
self.train_cameras = {} # 用于训练的相机 self.train_cameras = {} # 用于训练的相机
self.test_cameras = {} # 用于测试的相机 self.test_cameras = {} # 用于测试的相机
# 从COLMAP或Blender的输出结果中构建 场景信息(包括点云、训练用相机、测试用相机、场景归一化参数和点云文件路径) # 2. 从COLMAP或Blender的输出结果中构建 scene_info包含 点云、train相机info、test相机info、场景归一化参数、点云文件路径)
if os.path.exists(os.path.join(args.source_path, "sparse")): if os.path.exists(os.path.join(args.source_path, "sparse")):
scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
print("Found transforms_train.json file, assuming Blender data set!") print("Found transforms_train.json file, assuming Blender data set!")
scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
else: else:
assert False, "Could not recognize scene type!" assert False, "Could not recognize scene type! {}".format(os.path.join(args.source_path, "sparse"))
if not self.loaded_iter: if not self.loaded_iter:
# 如果没有加载模型则将点云文件points3D.ply文件复制到input.ply文件 # 未加载已训练模型,则:
# (1) 将点云文件points3D.ply文件复制到input.ply文件
with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
dest_file.write(src_file.read()) dest_file.write(src_file.read())
# (2) 将相机参数写入cameras.json文件
json_cams = [] json_cams = []
camlist = [] camlist = []
if scene_info.test_cameras: if scene_info.test_cameras: # 添加test相机到 camlist 中
# 测试相机添加到 camlist 中
camlist.extend(scene_info.test_cameras) camlist.extend(scene_info.test_cameras)
if scene_info.train_cameras: if scene_info.train_cameras: # 添加train相机到 camlist 中
# 训练相机添加到 camlist 中
camlist.extend(scene_info.train_cameras) camlist.extend(scene_info.train_cameras)
# 遍历 camlist 中的所有相机,使用 camera_to_JSON 函数将每个相机转换为 JSON 格式,并添加到 json_cams 列表中,并将 json_cams 写入 cameras.json 文件中
# 遍历 camlist 中的所有训练和测试相机,使用 camera_to_JSON 函数将每个相机转换为 JSON 格式,并添加到 json_cams 列表中,并将 json_cams 写入 cameras.json 文件中
for id, cam in enumerate(camlist): for id, cam in enumerate(camlist):
json_cams.append(camera_to_JSON(id, cam)) json_cams.append(camera_to_JSON(id, cam))
with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
json.dump(json_cams, file) json.dump(json_cams, file)
if shuffle: if shuffle:
# 随机打乱训练和测试相机列表 # 3. 随机打乱train、test相机info
random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
self.cameras_extent = scene_info.nerf_normalization["radius"] self.cameras_extent = scene_info.nerf_normalization["radius"]
# 根据resolution_scales加载不同分辨率的训练和测试相机包含R、T、视场角 # 4. 调整图片分辨率并根据train、test相机info(包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height)创建相机
for resolution_scale in resolution_scales: for resolution_scale in resolution_scales:
print("Loading Training Cameras") print("Loading Training Cameras")
self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
print("Loading Test Cameras") print("Loading Test Cameras")
self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
# 5. 加载高斯体
if self.loaded_iter: if self.loaded_iter:
# 如果加载已训练模型,则直接读取对应(已经迭代出来的)场景 # 如果加载已训练模型,则直接读取其已训练的高斯体
self.gaussians.load_ply(os.path.join(self.model_path, self.gaussians.load_ply(os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply"))
"point_cloud",
"iteration_" + str(self.loaded_iter),
"point_cloud.ply"))
else: else:
# 不加载训练模型,则调用 GaussianModel.create_from_pcd 从稀疏点云 scene_info.point_cloud 中建立模型 # 不加载,则从稀疏点云中初始建立高斯体
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
def save(self, iteration): def save(self, iteration):

View File

@ -96,8 +96,8 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
if intr.model=="SIMPLE_PINHOLE" or intr.model=="SIMPLE_RADIAL": if intr.model=="SIMPLE_PINHOLE" or intr.model=="SIMPLE_RADIAL":
# 如果是简单针孔模型,只有一个焦距参数 # 如果是简单针孔模型,只有一个焦距参数
focal_length_x = intr.params[0] focal_length_x = intr.params[0]
FovY = focal2fov(focal_length_x, height) # 计算垂直方向的视场角 FovY = focal2fov(focal_length_x, height) # 计算垂直方向的视场角: 2 * arctan(H / (2 * fy))
FovX = focal2fov(focal_length_x, width) # 计算水平方向的视场角 FovX = focal2fov(focal_length_x, width) # 计算水平方向的视场角: 2 * arctan(W / (2 * fx))
elif intr.model=="PINHOLE": elif intr.model=="PINHOLE":
# 如果是针孔模型,有两个焦距参数 # 如果是针孔模型,有两个焦距参数
focal_length_x = intr.params[0] focal_length_x = intr.params[0]
@ -115,7 +115,7 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
continue continue
image = Image.open(image_path) # PIL.Image读取的为RGB格式OpenCV读取的为BGR格式 image = Image.open(image_path) # PIL.Image读取的为RGB格式OpenCV读取的为BGR格式
# 创建相机信息类CameraInfo对象 (包含旋转矩阵、平移向量、视场角、图像数据、图片路径、图片名、宽度、高度)并添加到列表cam_infos中 # 创建相机信息类CameraInfo对象 (包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height)并添加到列表cam_infos中
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
image_path=image_path, image_name=image_name, width=width, height=height) image_path=image_path, image_name=image_name, width=width, height=height)
cam_infos.append(cam_info) cam_infos.append(cam_info)
@ -155,7 +155,7 @@ def storePly(path, xyz, rgb):
def readColmapSceneInfo(path, images, eval, llffhold=8): def readColmapSceneInfo(path, images, eval, llffhold=8):
''' '''
加载COLMAP的结果中的二进制相机外参文件images.bin 内参文件cameras.bin 加载COLMAP的结果中的二进制相机外参文件images.bin 内参文件cameras.bin
path: GaussianModel中的源文件路径 path: source_path
images: 'images' images: 'images'
eval: 是否为eval模式 eval: 是否为eval模式
llffhold: 默认为8 llffhold: 默认为8
@ -181,20 +181,20 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
# 根据图片名称排序,以保证顺序一致性 # 根据图片名称排序,以保证顺序一致性
cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : (x.image_path.split('/')[-2], int(x.image_name))) cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : (x.image_path.split('/')[-2], int(x.image_name)))
# 根据是否为评估模式eval将相机分为训练集和测试集 # 根据是否eval将相机分为训练集和测试集
# 如果为评估模式每llffhold张图片取一张作为测试集
if eval: if eval:
# 若要评测则每llffhold张图片取一张作为测试集
train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
else: else:
# 如果不是评估模式,所有相机均为训练相机,测试相机列表为空 # 不评测,则所有相机均为训练相机,测试相机列表为空
train_cam_infos = cam_infos train_cam_infos = cam_infos
test_cam_infos = [] test_cam_infos = []
# 计算场景归一化参数,这是为了处理不同尺寸和位置的场景,使模型训练更稳定 # 计算场景归一化参数,这是为了处理不同尺寸和位置的场景,使模型训练更稳定
nerf_normalization = getNerfppNorm(train_cam_infos) nerf_normalization = getNerfppNorm(train_cam_infos)
# 尝试读取COLMAP生成的稀疏点云数据优先从PLY文件读取如果不存在则尝试从BIN或TXT文件转换并保存为PLY格式 # 读取COLMAP生成的稀疏点云数据优先从PLY文件读取如果不存在则尝试从BIN或TXT文件转换并保存为PLY格式
ply_path = os.path.join(path, "sparse/0/points3D.ply") ply_path = os.path.join(path, "sparse/0/points3D.ply")
bin_path = os.path.join(path, "sparse/0/points3D.bin") bin_path = os.path.join(path, "sparse/0/points3D.bin")
txt_path = os.path.join(path, "sparse/0/points3D.txt") txt_path = os.path.join(path, "sparse/0/points3D.txt")
@ -206,10 +206,9 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
xyz, rgb, _ = read_points3D_text(txt_path) xyz, rgb, _ = read_points3D_text(txt_path)
storePly(ply_path, xyz, rgb) # 转换成ply文件 storePly(ply_path, xyz, rgb) # 转换成ply文件
# 读取PLY格式的稀疏点云
try: try:
pcd = fetchPly(ply_path) # points3D.ply读取COLMAP产生的稀疏点云 pcd = fetchPly(ply_path)
except: except:
pcd = None pcd = None

View File

@ -17,12 +17,17 @@ from utils.graphics_utils import fov2focal
WARNED = False WARNED = False
def loadCam(args, id, cam_info, resolution_scale): def loadCam(args, id, cam_info, resolution_scale):
"""
调整当前相机对应图像的分辨率，并根据当前相机的info创建相机（包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height）
"""
orig_w, orig_h = cam_info.image.size orig_w, orig_h = cam_info.image.size
# 1. 计算下采样后的图像尺寸
if args.resolution in [1, 2, 4, 8]: if args.resolution in [1, 2, 4, 8]:
# 计算下采样后的图像尺寸 [1, 1/2, 1/4, 1/8]
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
else: # should be a type that converts to float else:
if args.resolution == -1: if args.resolution == -1:
# 如果用户没有指定分辨率（即默认为-1），则自动判断图片的宽度是否>1.6K：如果大于，则自动计算下采样到1.6K时的采样倍率；如果小于，则采样倍率=1，即使用原图尺寸
if orig_w > 1600: if orig_w > 1600:
global WARNED global WARNED
if not WARNED: if not WARNED:
@ -33,19 +38,22 @@ def loadCam(args, id, cam_info, resolution_scale):
else: else:
global_down = 1 global_down = 1
else: else:
# 如果用户指定了分辨率,则根据用户指定的分辨率计算采样倍率
global_down = orig_w / args.resolution global_down = orig_w / args.resolution
scale = float(global_down) * float(resolution_scale) scale = float(global_down) * float(resolution_scale) # 缩放倍率
resolution = (int(orig_w / scale), int(orig_h / scale)) resolution = (int(orig_w / scale), int(orig_h / scale)) # 下采样后的图像尺寸
resized_image_rgb = PILtoTorch(cam_info.image, resolution) # 调整图片比例归一化并转换通道为torch上的 (C, H, W) # 2. 调整图片分辨率归一化并转换通道为torch上的 (C, H, W)
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
gt_image = resized_image_rgb[:3, ...] gt_image = resized_image_rgb[:3, ...]
loaded_mask = None loaded_mask = None
if resized_image_rgb.shape[1] == 4: if resized_image_rgb.shape[1] == 4:
# 如果图片有alpha通道则提取出来
loaded_mask = resized_image_rgb[3:4, ...] loaded_mask = resized_image_rgb[3:4, ...]
# 3. 创建相机
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY, FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask, image=gt_image, gt_alpha_mask=loaded_mask,
@ -53,12 +61,12 @@ def loadCam(args, id, cam_info, resolution_scale):
def cameraList_from_camInfos(cam_infos, resolution_scale, args): def cameraList_from_camInfos(cam_infos, resolution_scale, args):
''' '''
cam_infos: 训练或测试相机对象列表 cam_infos: train或test相机info列表
resolution_scale: 不同分辨率列表 resolution_scale: 分辨率倍率
args: 高斯模型参数 args: 更新后的ModelParams()中的参数
''' '''
camera_list = [] camera_list = []
# 遍历每个camera_info包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height
for id, c in enumerate(cam_infos): for id, c in enumerate(cam_infos):
camera_list.append( loadCam(args, id, c, resolution_scale) ) camera_list.append( loadCam(args, id, c, resolution_scale) )