mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-22 08:18:17 +00:00
parent
f934e701b2
commit
4ea5609081
5
train.py
5
train.py
@ -208,7 +208,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--quiet", action="store_true")
|
parser.add_argument("--quiet", action="store_true")
|
||||||
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||||
parser.add_argument("--start_checkpoint", type=str, default = None)
|
parser.add_argument("--start_checkpoint", type=str, default = None)
|
||||||
parser.add_argument('--cuda_blocking', action='store_true', default=True)
|
|
||||||
args = parser.parse_args(sys.argv[1:])
|
args = parser.parse_args(sys.argv[1:])
|
||||||
args.save_iterations.append(args.iterations)
|
args.save_iterations.append(args.iterations)
|
||||||
|
|
||||||
@ -217,9 +217,6 @@ if __name__ == "__main__":
|
|||||||
# Initialize system state (RNG)
|
# Initialize system state (RNG)
|
||||||
safe_state(args.quiet)
|
safe_state(args.quiet)
|
||||||
|
|
||||||
# CUDA sometimes fails - option to disable asynchronous operations
|
|
||||||
if args.cuda_blocking:
|
|
||||||
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
|
|
||||||
# Start GUI server, configure and run training
|
# Start GUI server, configure and run training
|
||||||
network_gui.init(args.ip, args.port)
|
network_gui.init(args.ip, args.port)
|
||||||
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
||||||
|
@ -18,13 +18,15 @@ import random
|
|||||||
def inverse_sigmoid(x):
|
def inverse_sigmoid(x):
|
||||||
return torch.log(x/(1-x))
|
return torch.log(x/(1-x))
|
||||||
|
|
||||||
def PILtoTorch(pil_image, resolution):
|
def PILtoTorch(pil_image, resolution, pin_memory=True):
|
||||||
resized_image_PIL = pil_image.resize(resolution)
|
resized_image_PIL = pil_image.resize(resolution)
|
||||||
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
|
resized_image = torch.from_numpy(np.array(resized_image_PIL, dtype=np.float32)) / 255.0
|
||||||
if len(resized_image.shape) == 3:
|
if resized_image.ndim == 2:
|
||||||
return resized_image.permute(2, 0, 1)
|
resized_image = resized_image[None]
|
||||||
else:
|
resized_image = resized_image.permute(2, 0, 1).contiguous()
|
||||||
return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
|
if pin_memory:
|
||||||
|
resized_image.pin_memory = True
|
||||||
|
return resized_image
|
||||||
|
|
||||||
def get_expon_lr_func(
|
def get_expon_lr_func(
|
||||||
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
|
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
|
||||||
|
Loading…
Reference in New Issue
Block a user