add comments

This commit is contained in:
liuzhi 2024-07-03 20:43:45 +08:00
parent 268f92fafc
commit b405eb02e7
3 changed files with 50 additions and 48 deletions

View File

@ -26,82 +26,77 @@ class Scene:
def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
"""
初始化3D场景对象
初始化3D场景对象
args: 存储着与 GaussianMoedl 相关参数 的args即包含scene/__init__.py/ModelParams()中的参数
gaussians: 3D高斯模型对象用于场景点的3D表示
load_iteration: 指定加载模型的迭代次数如果是-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型如果不是None且不是-1则加载指定迭代次数的
shuffle: 是否在训练前打乱相机列表
load_iteration: 训练时为 Nonerender时为 指定的iteration如果是-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型如果不是None且不是-1则加载指定迭代次数的
shuffle: 是否在训练前打乱相机列表render.py中会设为False
resolution_scales: 分辨率比例列表用于处理不同分辨率的相机
"""
self.model_path = args.model_path # 模型文件保存路径
self.loaded_iter = None # 已加载的迭代次数
self.gaussians = gaussians # 高斯模型对象
# 如果已有训练模型,则加载
# 1. 如果指定了加载模型的迭代次数则赋给Scene.loaded_iter
if load_iteration:
if load_iteration == -1:
# 是-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型,记录最大迭代次数
if load_iteration == -1: # 是-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型,记录最大迭代次数
self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
else:
# 不是None且不是-1则加载指定迭代次数的
else: # 不是None且不是-1则加载指定迭代次数的
self.loaded_iter = load_iteration
print("Loading trained model at iteration {}".format(self.loaded_iter))
self.train_cameras = {} # 用于训练的相机
self.test_cameras = {} # 用于测试的相机
# 从COLMAP或Blender的输出结果中构建 场景信息(包括点云、训练用相机、测试用相机、场景归一化参数和点云文件路径)
# 2. 从COLMAP或Blender的输出结果中构建 scene_info包含 点云、train相机info、test相机info、场景归一化参数、点云文件路径)
if os.path.exists(os.path.join(args.source_path, "sparse")):
scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
print("Found transforms_train.json file, assuming Blender data set!")
scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
else:
assert False, "Could not recognize scene type!"
assert False, "Could not recognize scene type! {}".format(os.path.join(args.source_path, "sparse"))
if not self.loaded_iter:
# 如果没有加载模型则将点云文件point3D.ply文件复制到input.ply文件
# 未加载已训练模型,则:
# (1) 将点云文件point3D.ply文件复制到input.ply文件
with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
dest_file.write(src_file.read())
# (2) 将相机参数写入cameras.json文件
json_cams = []
camlist = []
if scene_info.test_cameras:
# 测试相机添加到 camlist 中
if scene_info.test_cameras: # 添加test相机到 camlist 中
camlist.extend(scene_info.test_cameras)
if scene_info.train_cameras:
# 训练相机添加到 camlist 中
if scene_info.train_cameras: # 添加train相机到 camlist 中
camlist.extend(scene_info.train_cameras)
# 遍历 camlist 中的所有相机,使用 camera_to_JSON 函数将每个相机转换为 JSON 格式,并添加到 json_cams 列表中,并将 json_cams 写入 cameras.json 文件中
# 遍历 camlist 中的所有训练和测试相机,使用 camera_to_JSON 函数将每个相机转换为 JSON 格式,并添加到 json_cams 列表中,并将 json_cams 写入 cameras.json 文件中
for id, cam in enumerate(camlist):
json_cams.append(camera_to_JSON(id, cam))
with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
json.dump(json_cams, file)
if shuffle:
# 随机打乱训练和测试相机列表
# 3. 随机打乱train、test相机info
random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
self.cameras_extent = scene_info.nerf_normalization["radius"]
# 根据resolution_scales加载不同分辨率的训练和测试相机包含R、T、视场角
# 4. 调整图片分辨率并根据train、test相机info(包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height)创建相机
for resolution_scale in resolution_scales:
print("Loading Training Cameras")
self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
print("Loading Test Cameras")
self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
# 5. 加载高斯体
if self.loaded_iter:
# 如果加载已训练模型,则直接读取对应(已经迭代出来的)场景
self.gaussians.load_ply(os.path.join(self.model_path,
"point_cloud",
"iteration_" + str(self.loaded_iter),
"point_cloud.ply"))
# 如果加载已训练模型,则直接读取其已训练的高斯体
self.gaussians.load_ply(os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply"))
else:
# 不加载训练模型,则调用 GaussianModel.create_from_pcd 从稀疏点云 scene_info.point_cloud 中建立模型
# 不加载,则从稀疏点云中初始建立高斯体
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
def save(self, iteration):

View File

@ -96,8 +96,8 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
if intr.model=="SIMPLE_PINHOLE" or intr.model=="SIMPLE_RADIAL":
# 如果是简单针孔模型,只有一个焦距参数
focal_length_x = intr.params[0]
FovY = focal2fov(focal_length_x, height) # 计算垂直方向的视场角
FovX = focal2fov(focal_length_x, width) # 计算水平方向的视场角
FovY = focal2fov(focal_length_x, height) # 计算垂直方向的视场角: 2 * arctan(H / 2fy))
FovX = focal2fov(focal_length_x, width) # 计算水平方向的视场角: 2 * arctan(W / 2fx)
elif intr.model=="PINHOLE":
# 如果是针孔模型,有两个焦距参数
focal_length_x = intr.params[0]
@ -115,7 +115,7 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
continue
image = Image.open(image_path) # PIL.Image读取的为RGB格式OpenCV读取的为BGR格式
# 创建相机信息类CameraInfo对象 (包含旋转矩阵、平移向量、视场角、图像数据、图片路径、图片名、宽度、高度)并添加到列表cam_infos中
# 创建相机信息类CameraInfo对象 (包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height)并添加到列表cam_infos中
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
image_path=image_path, image_name=image_name, width=width, height=height)
cam_infos.append(cam_info)
@ -155,7 +155,7 @@ def storePly(path, xyz, rgb):
def readColmapSceneInfo(path, images, eval, llffhold=8):
'''
加载COLMAP的结果中的二进制相机外参文件imags.bin 内参文件cameras.bin
path: GaussianModel中的源文件路径
path: source_path
images: 'images'
eval: 是否为eval模式
llffhold: 默认为8
@ -181,20 +181,20 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
# 根据图片名称排序,以保证顺序一致性
cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : (x.image_path.split('/')[-2], int(x.image_name)))
# 根据是否为评估模式eval将相机分为训练集和测试集
# 如果为评估模式每llffhold张图片取一张作为测试集
# 根据是否eval将相机分为训练集和测试集
if eval:
# 若要评测则每llffhold张图片取一张作为测试集
train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
else:
# 如果不是评估模式,所有相机均为训练相机,测试相机列表为空
# 不评测,则所有相机均为训练相机,测试相机列表为空
train_cam_infos = cam_infos
test_cam_infos = []
# 计算场景归一化参数,这是为了处理不同尺寸和位置的场景,使模型训练更稳定
nerf_normalization = getNerfppNorm(train_cam_infos)
# 尝试读取COLMAP生成的稀疏点云数据优先从PLY文件读取如果不存在则尝试从BIN或TXT文件转换并保存为PLY格式
# 读取COLMAP生成的稀疏点云数据优先从PLY文件读取如果不存在则尝试从BIN或TXT文件转换并保存为PLY格式
ply_path = os.path.join(path, "sparse/0/points3D.ply")
bin_path = os.path.join(path, "sparse/0/points3D.bin")
txt_path = os.path.join(path, "sparse/0/points3D.txt")
@ -206,10 +206,9 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
xyz, rgb, _ = read_points3D_text(txt_path)
storePly(ply_path, xyz, rgb) # 转换成ply文件
# 读取PLY格式的稀疏点云
try:
pcd = fetchPly(ply_path) # points3D.ply读取COLMAP产生的稀疏点云
pcd = fetchPly(ply_path)
except:
pcd = None

View File

@ -17,12 +17,17 @@ from utils.graphics_utils import fov2focal
WARNED = False
def loadCam(args, id, cam_info, resolution_scale):
"""
调整当前相机对应图像的分辨率并根据当前相机的info创建相机包含RTFovYFovX图像数据imageimage_pathimage_namewidthheight
"""
orig_w, orig_h = cam_info.image.size
# 1. 计算下采样后的图像尺寸
if args.resolution in [1, 2, 4, 8]:
# 计算下采样后的图像尺寸 [1, 1/2, 1/4, 1/8]
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
else: # should be a type that converts to float
else:
if args.resolution == -1:
# 如果用户没有指定分辨率,即默认为-1则自动判断图片的宽度是>1.6K如果大于则自动进行下采样到1.6K时的采样倍率;如果小于,则采样倍率=1即使用原图尺寸
if orig_w > 1600:
global WARNED
if not WARNED:
@ -33,19 +38,22 @@ def loadCam(args, id, cam_info, resolution_scale):
else:
global_down = 1
else:
# 如果用户指定了分辨率,则根据用户指定的分辨率计算采样倍率
global_down = orig_w / args.resolution
scale = float(global_down) * float(resolution_scale)
resolution = (int(orig_w / scale), int(orig_h / scale))
scale = float(global_down) * float(resolution_scale) # 缩放倍率
resolution = (int(orig_w / scale), int(orig_h / scale)) # 下采样后的图像尺寸
resized_image_rgb = PILtoTorch(cam_info.image, resolution) # 调整图片比例归一化并转换通道为torch上的 (C, H, W)
# 2. 调整图片分辨率归一化并转换通道为torch上的 (C, H, W)
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
gt_image = resized_image_rgb[:3, ...]
loaded_mask = None
if resized_image_rgb.shape[1] == 4:
# 如果图片有alpha通道则提取出来
loaded_mask = resized_image_rgb[3:4, ...]
# 3. 创建相机
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
@ -53,14 +61,14 @@ def loadCam(args, id, cam_info, resolution_scale):
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
'''
cam_infos: 训练或测试相机对象列表
resolution_scale: 不同分辨率列表
args: 高斯模型参数
cam_infos: train或test相机info列表
resolution_scale: 分辨率倍率
args: 更新后的ModelParams()中的参数
'''
camera_list = []
# 遍历每个camera_info包含R、T、FovY、FovX、图像数据image、image_path、image_name、width、height
for id, c in enumerate(cam_infos):
camera_list.append(loadCam(args, id, c, resolution_scale))
camera_list.append( loadCam(args, id, c, resolution_scale) )
return camera_list