gaussian-splatting/scene/__init__.py
2024-05-09 20:17:55 +08:00

118 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
import random
import json
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
class Scene:
"""
Scene 类用于管理场景的3D模型包括相机参数、点云数据和高斯模型的初始化和加载
"""
gaussians : GaussianModel
def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
"""
初始化场景对象
:param args: 包含模型路径和源路径等模型参数
:param gaussians: 高斯模型对象用于场景点的3D表示
:param load_iteration: 指定加载模型的迭代次数如果不为None且为-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型,且不为-1则加载指定迭代次数的
:param shuffle: 是否在训练前打乱相机列表
:param resolution_scales: 分辨率比例列表,用于处理不同分辨率的相机
"""
self.model_path = args.model_path # 模型文件保存路径
self.loaded_iter = None # 已加载的迭代次数
self.gaussians = gaussians # 高斯模型对象
# 检查并加载已有的训练模型
if load_iteration:
# 不为None
if load_iteration == -1:
# 且为-1则在输出文件夹下的point_cloud/文件夹下搜索迭代次数最大的模型,记录最大迭代次数
self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
else:
# 不为-1则加载指定迭代次数的
self.loaded_iter = load_iteration
print("Loading trained model at iteration {}".format(self.loaded_iter))
self.train_cameras = {} # 用于训练的相机参数
self.test_cameras = {} # 用于测试的相机参数
# 判断数据集类型是COLMAP的输出还是Blender得输出并从中加载场景信息
if os.path.exists(os.path.join(args.source_path, "sparse")):
scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
print("Found transforms_train.json file, assuming Blender data set!")
scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
else:
assert False, "Could not recognize scene type!"
# loaded_iter = None模型还未训练过
if not self.loaded_iter:
with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
dest_file.write(src_file.read())
json_cams = []
camlist = []
if scene_info.test_cameras:
camlist.extend(scene_info.test_cameras)
if scene_info.train_cameras:
camlist.extend(scene_info.train_cameras)
for id, cam in enumerate(camlist):
json_cams.append(camera_to_JSON(id, cam))
with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
json.dump(json_cams, file)
if shuffle:
random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
self.cameras_extent = scene_info.nerf_normalization["radius"]
# 根据resolution_scales加载不同分辨率的训练和测试位姿
for resolution_scale in resolution_scales:
print("Loading Training Cameras")
self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
print("Loading Test Cameras")
self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
if self.loaded_iter:
# 直接读取对应(已经迭代出来的)场景
self.gaussians.load_ply(os.path.join(self.model_path,
"point_cloud",
"iteration_" + str(self.loaded_iter),
"point_cloud.ply"))
else:
# loaded_iter = None模型还未训练过调用GaussianModel.create_from_pcd从scene_info.point_cloud中建立模型
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
def save(self, iteration):
"""
保存当前迭代下的3D高斯模型点云。
iteration: 当前的迭代次数
"""
point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
def getTrainCameras(self, scale=1.0):
"""
获取指定分辨率比例的训练相机列表
scale: 分辨率比例
return: 指定分辨率比例的训练相机列表
"""
return self.train_cameras[scale]
def getTestCameras(self, scale=1.0):
return self.test_cameras[scale]