measure fps

This commit is contained in:
liuzhi 2024-06-04 21:07:22 +08:00
parent 1a35943d4c
commit 9e82ac7170

View File

@ -21,6 +21,8 @@ from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import torch.utils.benchmark as benchmark
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
@ -47,6 +49,25 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam
if not skip_test:
render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
#
# measure_fps(scene, gaussians, pipeline, background)
#
# def render_fn(views, gaussians, pipeline, background):
# with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False):
# for view in views:
# render(view, gaussians, pipeline, background)
# def measure_fps(scene, gaussians, pipeline, background):
# with torch.no_grad():
# views = scene.getTrainCameras() + scene.getTestCameras()
# t0 = benchmark.Timer(stmt='render_fn(views, gaussians, pipeline, background)',
# setup='from __main__ import render_fn',
# globals={'views': views, 'gaussians': gaussians, 'pipeline': pipeline,
# 'background': background},
# )
# time = t0.timeit(50)
# fps = len(views)/time.median
# print("Rendering FPS: ", fps)
# return fps
if __name__ == "__main__":
# Set up command line argument parser