mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-04-06 05:55:06 +00:00
Merge 4ea5609081
into 8a70a8cd6f
This commit is contained in:
commit
7ece03d6a7
41
convert.py
41
convert.py
@ -8,12 +8,16 @@
|
||||
#
|
||||
# For inquiries contact george.drettakis@inria.fr
|
||||
#
|
||||
# xvdp removed magick, even single threaded PIL resizes 4X faster
|
||||
|
||||
|
||||
import os
|
||||
import logging
|
||||
from argparse import ArgumentParser
|
||||
import shutil
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# This Python script is based on the shell converter script provided in the MipNerF 360 repository.
|
||||
parser = ArgumentParser("Colmap converter")
|
||||
parser.add_argument("--no_gpu", action='store_true')
|
||||
@ -25,7 +29,7 @@ parser.add_argument("--resize", action="store_true")
|
||||
parser.add_argument("--magick_executable", default="", type=str)
|
||||
args = parser.parse_args()
|
||||
colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
|
||||
magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
|
||||
|
||||
use_gpu = 1 if not args.no_gpu else 0
|
||||
|
||||
if not args.skip_matching:
|
||||
@ -87,38 +91,21 @@ for file in files:
|
||||
destination_file = os.path.join(args.source_path, "sparse", "0", file)
|
||||
shutil.move(source_file, destination_file)
|
||||
|
||||
if args.resize:
    print("Copying and resizing...")

    # Each divisor gets its own "images_<divisor>" folder next to "images".
    scale_divisors = (2, 4, 8)
    for div in scale_divisors:
        os.makedirs(os.path.join(args.source_path, f"images_{div}"), exist_ok=True)

    # Get the list of files in the source directory
    files = os.listdir(os.path.join(args.source_path, "images"))

    # Write a downscaled copy of every source image into each target folder.
    # Progress is reported 1-based so the last message reads [N/N].
    # NOTE(review): logging.info is only shown if the logging level is set to
    # INFO somewhere; no logging configuration is visible in this excerpt.
    for j, file in enumerate(files, start=1):
        source_file = os.path.join(args.source_path, "images", file)
        logging.info(f"processing image [{j}/{len(files)}] {source_file}")
        # The context manager releases the file handle once all sizes are saved.
        with Image.open(source_file) as im:
            for div in scale_divisors:
                destination_file = os.path.join(args.source_path, f"images_{div}", file)
                # Clamp to 1 pixel: a rounded-down 0 is not a valid resize target.
                new_size = [max(1, round(side / div)) for side in im.size]
                im.resize(new_size, Image.BICUBIC).save(destination_file, quality=100)

print("Done.")
|
||||
|
@ -56,6 +56,16 @@ class Camera(nn.Module):
|
||||
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
|
||||
self.camera_center = self.world_view_transform.inverse()[3, :3]
|
||||
|
||||
def __repr__(self):
    """Return a multi-line summary of this object's instance attributes.

    The first line is the class name followed by ``()``. Each following line
    holds one entry of ``self.__dict__`` as ``" name:<TAB>value"``. Tensors
    with more than 16 elements are shown by shape only; everything else is
    formatted in full.
    """
    summary = [self.__class__.__name__ + "()"]
    for k, v in self.__dict__.items():
        # Large tensors would flood the output, so print only their shape.
        if torch.is_tensor(v) and v.numel() > 16:
            shown = tuple(v.shape)
        else:
            shown = v
        # Same leading space for both kinds of entry keeps the listing aligned.
        summary.append(f" {k}:\t{shown}")
    return "\n".join(summary) + "\n"
|
||||
|
||||
|
||||
class MiniCam:
|
||||
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
|
||||
self.image_width = width
|
||||
|
@ -58,6 +58,17 @@ class GaussianModel:
|
||||
self.spatial_lr_scale = 0
|
||||
self.setup_functions()
|
||||
|
||||
|
||||
def __repr__(self):
    """Return a multi-line summary of this object's instance attributes.

    The first line is the class name followed by ``()``. Each following line
    holds one entry of ``self.__dict__`` as ``" name:<TAB>value"``. Every
    tensor is shown by shape only; other values are formatted in full.
    """
    summary = [self.__class__.__name__ + "()"]
    for k, v in self.__dict__.items():
        # Tensors are summarised by shape instead of dumping their contents.
        shown = tuple(v.shape) if torch.is_tensor(v) else v
        # Same leading space for both kinds of entry keeps the listing aligned.
        summary.append(f" {k}:\t{shown}")
    return "\n".join(summary) + "\n"
|
||||
|
||||
|
||||
def capture(self):
|
||||
return (
|
||||
self.active_sh_degree,
|
||||
@ -73,7 +84,7 @@ class GaussianModel:
|
||||
self.optimizer.state_dict(),
|
||||
self.spatial_lr_scale,
|
||||
)
|
||||
|
||||
|
||||
def restore(self, model_args, training_args):
|
||||
(self.active_sh_degree,
|
||||
self._xyz,
|
||||
|
15
train.py
15
train.py
@ -48,7 +48,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
||||
ema_loss_for_log = 0.0
|
||||
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
|
||||
first_iter += 1
|
||||
for iteration in range(first_iter, opt.iterations + 1):
|
||||
for iteration in range(first_iter, opt.iterations + 1):
|
||||
if network_gui.conn == None:
|
||||
network_gui.try_connect()
|
||||
while network_gui.conn != None:
|
||||
@ -62,7 +62,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
||||
if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
|
||||
break
|
||||
except Exception as e:
|
||||
network_gui.conn.close()
|
||||
network_gui.conn = None
|
||||
network_gui.listener.close()
|
||||
network_gui.listener = None
|
||||
|
||||
iter_start.record()
|
||||
|
||||
@ -162,7 +165,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
|
||||
# Report test and samples of training set
|
||||
if iteration in testing_iterations:
|
||||
torch.cuda.empty_cache()
|
||||
validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
|
||||
validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
|
||||
{'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
|
||||
|
||||
for config in validation_configs:
|
||||
@ -203,11 +206,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||
parser.add_argument("--quiet", action="store_true")
|
||||
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||
parser.add_argument("--start_checkpoint", type=str, default = None)
|
||||
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
args.save_iterations.append(args.iterations)
|
||||
|
||||
|
||||
print("Optimizing " + args.model_path)
|
||||
|
||||
# Initialize system state (RNG)
|
||||
@ -216,7 +220,8 @@ if __name__ == "__main__":
|
||||
# Start GUI server, configure and run training
|
||||
network_gui.init(args.ip, args.port)
|
||||
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
||||
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
|
||||
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations,
|
||||
args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
|
||||
|
||||
# All done
|
||||
print("\nTraining complete.")
|
||||
|
@ -18,13 +18,15 @@ import random
|
||||
def inverse_sigmoid(x):
    """Return ``log(x / (1 - x))`` of ``x``, evaluated with ``torch.log``."""
    complement = 1 - x
    return torch.log(x / complement)
|
||||
|
||||
def PILtoTorch(pil_image, resolution, pin_memory=True):
    """Resize an image and convert it to a channel-first float32 tensor.

    Args:
        pil_image: object exposing ``resize(resolution)`` whose result
            ``numpy.array`` can convert (presumably a ``PIL.Image.Image``;
            confirm against callers).
        resolution: target size, passed unchanged to ``pil_image.resize``.
        pin_memory: when True and CUDA is available, the returned tensor is
            copied into page-locked host memory. Without CUDA the flag is
            ignored, so CPU-only runs keep working.

    Returns:
        A contiguous float32 tensor of shape (C, H, W) holding the pixel
        values divided by 255. A 2-D (single-channel) image yields (1, H, W).
    """
    resized_image_PIL = pil_image.resize(resolution)
    resized_image = torch.from_numpy(np.array(resized_image_PIL, dtype=np.float32)) / 255.0
    if resized_image.ndim == 2:
        # Add the channel axis LAST so the permute below yields (1, H, W);
        # a leading axis would be permuted into (W, 1, H).
        resized_image = resized_image.unsqueeze(-1)
    resized_image = resized_image.permute(2, 0, 1).contiguous()
    if pin_memory and torch.cuda.is_available():
        # pin_memory() returns a pinned copy; assigning to the attribute
        # would only overwrite the method and pin nothing.
        resized_image = resized_image.pin_memory()
    return resized_image
|
||||
|
||||
def get_expon_lr_func(
|
||||
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
|
||||
|
Loading…
Reference in New Issue
Block a user