DreamCraft3D/gradio_app.py
2023-12-15 17:44:44 +08:00

451 lines
14 KiB
Python

import argparse
import glob
import os
import re
import signal
import subprocess
import tempfile
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
import gradio as gr
import numpy as np
import psutil
import trimesh
def tail(f, window=20):
    """Return the last ``window`` lines of the binary file object ``f``.

    The file is read backwards from its end in blocks of ``BUFSIZ`` bytes and
    reading stops as soon as enough newlines have been collected.

    Args:
        f: A seekable file object opened in binary mode.
        window: Maximum number of trailing lines to return.

    Returns:
        The trailing lines joined with newline characters as one string
        (an empty string for an empty file). When ``window`` is 0 an empty
        list is returned; that historical return value is kept unchanged.
    """
    if window == 0:
        return []

    BUFSIZ = 1024
    f.seek(0, 2)
    remaining_bytes = f.tell()
    # Collect one newline more than requested so that the earliest line we
    # keep is not a fragment cut off at a block boundary.
    size = window + 1
    block = -1
    chunks = []
    while size > 0 and remaining_bytes > 0:
        if remaining_bytes - BUFSIZ > 0:
            # Seek back one whole BUFSIZ from the end and read that block.
            f.seek(block * BUFSIZ, 2)
            bunch = f.read(BUFSIZ)
        else:
            # Less than one block is left: read only the unread head.
            f.seek(0, 0)
            bunch = f.read(remaining_bytes)
        chunks.insert(0, bunch)
        size -= bunch.count(b"\n")
        remaining_bytes -= BUFSIZ
        block -= 1

    # Decode once, after joining: decoding each block separately fails when a
    # multi-byte UTF-8 character straddles a block boundary. The first block
    # may still start in the middle of a character, hence errors="replace".
    text = b"".join(chunks).decode("utf-8", errors="replace")
    return "\n".join(text.splitlines()[-window:])
@dataclass
class ExperimentStatus:
    """Snapshot of one training run, in the shape the UI callbacks emit."""

    # process id of the training process
    pid: Optional[int] = None
    # human-readable progress text
    progress: str = ""
    # trailing lines of the training log
    log: str = ""
    # paths of the most recent artifacts, when available
    output_image: Optional[str] = None
    output_video: Optional[str] = None
    output_mesh: Optional[str] = None

    def tolist(self):
        """Return the field values as a list, in declaration order."""
        ordered_fields = (
            "pid",
            "progress",
            "log",
            "output_image",
            "output_video",
            "output_mesh",
        )
        return [getattr(self, field_name) for field_name in ordered_fields]
# Root directory under which every trial directory is created.
EXP_ROOT_DIR = "outputs-gradio"
# Prompt pre-filled in the prompt textbox.
DEFAULT_PROMPT = "a delicious hamburger"

# (display name, config file path) pairs offered in the model dropdown.
model_config = [
    ("DreamFusion (DeepFloyd-IF)", "configs/gradio/dreamfusion-if.yaml"),
    ("DreamFusion (Stable Diffusion)", "configs/gradio/dreamfusion-sd.yaml"),
    ("TextMesh (DeepFloyd-IF)", "configs/gradio/textmesh-if.yaml"),
    ("Fantasia3D (Stable Diffusion, Geometry Only)", "configs/gradio/fantasia3d.yaml"),
    ("SJC (Stable Diffusion)", "configs/gradio/sjc.yaml"),
    ("Latent-NeRF (Stable Diffusion)", "configs/gradio/latentnerf.yaml"),
]

# Display names in dropdown order, and the lookup from name to config path.
model_choices = [display_name for display_name, _ in model_config]
model_name_to_config = dict(model_config)
def load_model_config(model_name):
    """Return the raw text of the config file registered for ``model_name``.

    Args:
        model_name: A key of ``model_name_to_config``.

    Returns:
        The full contents of the config file as a string.

    Raises:
        KeyError: If ``model_name`` is not a registered model.
        OSError: If the config file cannot be opened.
    """
    # Use a context manager so the handle is closed right after reading.
    with open(model_name_to_config[model_name]) as config_fp:
        return config_fp.read()
def load_model_config_attrs(model_name):
    """Parse the config of ``model_name`` and return the attributes the UI uses.

    Returns:
        A dict with the raw config text under ``"source"``, and the
        ``guidance_scale`` and ``max_steps`` values read from the parsed
        config.
    """
    raw_config = load_model_config(model_name)

    # Imported here rather than at module level, as in the original layout.
    from threestudio.utils.config import load_config

    # Throwaway overrides; presumably they only fill in fields the config
    # needs in order to load — TODO confirm against threestudio's loader.
    placeholder_overrides = [
        "name=dummy",
        "tag=dummy",
        "use_timestamp=false",
        f"exp_root_dir={EXP_ROOT_DIR}",
        "system.prompt_processor.prompt=placeholder",
    ]
    parsed = load_config(
        raw_config,
        cli_args=placeholder_overrides,
        from_string=True,
    )

    attrs = {"source": raw_config}
    attrs["guidance_scale"] = parsed.system.guidance.guidance_scale
    attrs["max_steps"] = parsed.trainer.max_steps
    return attrs
def on_model_selector_change(model_name):
    """Return the config text and guidance scale for the newly chosen model.

    The two values are returned in the order config source, guidance scale.
    """
    attrs = load_model_config_attrs(model_name)
    config_source = attrs["source"]
    guidance_scale = attrs["guidance_scale"]
    return [config_source, guidance_scale]
def _latest_by_step(paths, pattern):
    """Return the path whose basename encodes the highest step number.

    ``pattern`` is matched against each basename and must capture the step
    digits in group 1. Paths that do not match are skipped instead of
    raising. Returns ``None`` when no path matches.
    """
    latest_path, latest_step = None, None
    for path in paths:
        matched = re.match(pattern, os.path.basename(path))
        if matched is None:
            continue
        step = int(matched.group(1))
        # ">=" keeps the later entry on ties, like a stable sort taking [-1].
        if latest_step is None or step >= latest_step:
            latest_path, latest_step = path, step
    return latest_path


def get_current_status(process, trial_dir, alive_path):
    """Collect the current state of a training run from its trial directory.

    Besides reading state, this refreshes the timestamp stored in
    ``alive_path`` so the watcher can tell the run is still being polled.

    Args:
        process: The training process object; only its ``pid`` is read.
        trial_dir: Directory holding the ``logs``, ``progress`` and ``save``
            entries of the run.
        alive_path: Path of the heartbeat file to refresh.

    Returns:
        An ``ExperimentStatus`` describing progress, log tail and the latest
        image / video / mesh outputs found.
    """
    status = ExperimentStatus()
    status.pid = process.pid

    # write the current timestamp to the alive file
    # the watcher will know the last active time of this process from this timestamp
    if os.path.exists(os.path.dirname(alive_path)):
        with open(alive_path, "w") as alive_fp:
            alive_fp.write(str(time.time()))

    log_path = os.path.join(trial_dir, "logs")
    progress_path = os.path.join(trial_dir, "progress")
    save_path = os.path.join(trial_dir, "save")

    # read current progress from the progress file
    # the progress file is created by GradioCallback
    if os.path.exists(progress_path):
        with open(progress_path) as progress_fp:
            status.progress = progress_fp.read()
    else:
        status.progress = "Setting up everything ..."

    # read the last 10 lines of the log file
    if os.path.exists(log_path):
        with open(log_path, "rb") as log_fp:
            status.log = tail(log_fp, window=10)
    else:
        status.log = ""

    # get the validation image, testing video and exported mesh if they exist
    if os.path.exists(save_path):
        status.output_image = _latest_by_step(
            glob.glob(os.path.join(save_path, "*.png")), r"it(\d+)-0\.png"
        )
        status.output_video = _latest_by_step(
            glob.glob(os.path.join(save_path, "*.mp4")), r"it(\d+)-test\.mp4"
        )
        export_dir = _latest_by_step(
            glob.glob(os.path.join(save_path, "*export")), r"it(\d+)-export"
        )
        if export_dir is not None:
            obj = glob.glob(os.path.join(export_dir, "*.obj"))
            if obj:
                # FIXME
                # seems the gr.Model3D cannot load our manually saved obj file
                # here we load the obj and save it to a temporary file using trimesh
                # Only the name is needed; close the handle right away so it
                # does not leak. A new temporary file is created per call.
                with tempfile.NamedTemporaryFile(
                    suffix=".obj", delete=False
                ) as mesh_fp:
                    mesh_path = mesh_fp.name
                trimesh.load(obj[0]).export(mesh_path)
                status.output_mesh = mesh_path

    return status
def run(
    model_name: str,
    config: str,
    prompt: str,
    guidance_scale: float,
    seed: int,
    max_steps: int,
):
    """Start a training run plus its watcher and stream status to the UI.

    This is a generator: while the training process is alive it yields a
    status snapshot once per ``status_update_interval`` seconds, and yields
    one final snapshot after both child processes have finished.

    Args:
        model_name: Key of ``model_name_to_config``; selects the run name.
        config: Full config text, written to a temporary file for the
            training process.
        prompt: Text prompt passed as a config override.
        guidance_scale: Guidance scale override.
        seed: Random seed override.
        max_steps: Override for the number of training steps.

    Yields:
        The ``ExperimentStatus`` fields as a list, followed by update
        objects for the Run and Stop buttons.
    """
    # Function-scope import: only needed here, to quote the prompt.
    import json

    # update status every 1 second
    status_update_interval = 1

    # save the config to a temporary file
    # (the file exists for as long as `config_file` stays referenced)
    config_file = tempfile.NamedTemporaryFile()
    with open(config_file.name, "w") as f:
        f.write(config)

    # manually assign the output directory, name and tag so that we know the trial directory
    name = os.path.basename(model_name_to_config[model_name]).split(".")[0]
    tag = datetime.now().strftime("@%Y%m%d-%H%M%S")
    trial_dir = os.path.join(EXP_ROOT_DIR, name, tag)
    alive_path = os.path.join(trial_dir, "alive")

    # The prompt is free-form user input. Wrapping it in bare double quotes
    # breaks as soon as it contains a quote or backslash, so emit a properly
    # escaped double-quoted scalar instead. For prompts without special
    # characters the result is identical to the previous quoting.
    # NOTE(review): assumes the launcher parses override values as
    # YAML/JSON-style scalars, as the original quoting implies — confirm.
    quoted_prompt = json.dumps(prompt, ensure_ascii=False)

    # spawn the training process
    process = subprocess.Popen(
        f"python launch.py --config {config_file.name} --train --gpu 0 --gradio trainer.enable_progress_bar=false".split()
        + [
            f'name="{name}"',
            f'tag="{tag}"',
            f"exp_root_dir={EXP_ROOT_DIR}",
            "use_timestamp=false",
            f"system.prompt_processor.prompt={quoted_prompt}",
            f"system.guidance.guidance_scale={guidance_scale}",
            f"seed={seed}",
            f"trainer.max_steps={max_steps}",
        ]
    )

    # spawn the watcher process
    watch_process = subprocess.Popen(
        "python gradio_app.py watch".split()
        + ["--pid", f"{process.pid}", "--trial-dir", f"{trial_dir}"]
    )

    # update status (progress, log, image, video) every status_update_interval seconds
    # button status: Run -> Stop
    while process.poll() is None:
        time.sleep(status_update_interval)
        yield get_current_status(process, trial_dir, alive_path).tolist() + [
            gr.update(visible=False),
            gr.update(value="Stop", variant="stop", visible=True),
        ]

    # wait for the processes to finish
    process.wait()
    watch_process.wait()

    # update status one last time
    # button status: Stop / Reset -> Run
    status = get_current_status(process, trial_dir, alive_path)
    status.progress = "Finished."
    yield status.tolist() + [
        gr.update(value="Run", variant="primary", visible=True),
        gr.update(visible=False),
    ]
def stop_run(pid):
    """Kill the training process ``pid`` and switch the buttons to "Reset".

    Killing is best effort: a failure is reported on stdout and the button
    updates are returned regardless.

    Args:
        pid: Process id held in the UI state (may be unset).

    Returns:
        Update objects for the Run and Stop buttons, in that order.
    """
    # kill the process
    print(f"Trying to kill process {pid} ...")
    try:
        os.kill(pid, signal.SIGKILL)
    except Exception as exc:
        # Covers an already-exited process, missing permission, or an unset
        # pid, without also swallowing KeyboardInterrupt / SystemExit as a
        # bare ``except:`` would.
        print(f"Exception when killing process {pid}.")
        print(f"Reason: {exc!r}")

    # button status: Stop -> Reset
    return [
        gr.update(value="Reset", variant="secondary", visible=True),
        gr.update(visible=False),
    ]
def launch(port, listen=False):
    """Build the web demo and serve it.

    Args:
        port: Port number passed to the server as ``server_port``.
        listen: When true, the server name is set to ``0.0.0.0``.
    """
    intro_markdown = (
        "\n"
        "# threestudio\n"
        "- Select a model from the dropdown menu.\n"
        "- Input a text prompt.\n"
        "- Hit Run!\n"
    )

    with gr.Blocks(title="threestudio - Web Demo") as demo:
        with gr.Row():
            # holds the pid of the running training process
            pid = gr.State()

            # left column: inputs and controls
            with gr.Column(scale=1):
                gr.Markdown(intro_markdown)

                # model selection dropdown
                model_selector = gr.Dropdown(
                    value=model_choices[0],
                    choices=model_choices,
                    label="Select a model",
                )

                # prompt input
                prompt_input = gr.Textbox(value=DEFAULT_PROMPT, label="Input prompt")

                # guidance scale slider, initialised from the selected config
                initial_attrs = load_model_config_attrs(model_selector.value)
                guidance_scale_input = gr.Slider(
                    minimum=0.0,
                    maximum=100.0,
                    value=initial_attrs["guidance_scale"],
                    step=0.5,
                    label="Guidance scale",
                )

                # seed slider
                seed_input = gr.Slider(
                    minimum=0, maximum=2147483647, value=0, step=1, label="Seed"
                )

                # training length slider
                max_steps_input = gr.Slider(
                    minimum=1,
                    maximum=5000,
                    value=5000,
                    step=1,
                    label="Number of training steps",
                )

                # full config viewer
                with gr.Accordion("See full configurations", open=False):
                    config_editor = gr.Code(
                        value=load_model_config(model_selector.value),
                        language="yaml",
                        interactive=False,
                    )

                # load config on model selection change
                model_selector.change(
                    fn=on_model_selector_change,
                    inputs=model_selector,
                    outputs=[config_editor, guidance_scale_input],
                )

                run_btn = gr.Button(value="Run", variant="primary")
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)

                # generation status
                status = gr.Textbox(
                    value="Hit the Run button to start.",
                    label="Status",
                    lines=1,
                    max_lines=1,
                )

            # right column: logs and outputs
            with gr.Column(scale=1):
                with gr.Accordion("See terminal logs", open=False):
                    # logs
                    logs = gr.Textbox(label="Logs", lines=10)

                # validation image display
                output_image = gr.Image(value=None, label="Image")

                # testing video display
                output_video = gr.Video(value=None, label="Video")

                # export mesh display
                output_mesh = gr.Model3D(value=None, label="3D Mesh")

        # wiring: Run streams status into the outputs, Stop kills and cancels
        run_inputs = [
            model_selector,
            config_editor,
            prompt_input,
            guidance_scale_input,
            seed_input,
            max_steps_input,
        ]
        run_outputs = [
            pid,
            status,
            logs,
            output_image,
            output_video,
            output_mesh,
            run_btn,
            stop_btn,
        ]
        run_event = run_btn.click(fn=run, inputs=run_inputs, outputs=run_outputs)
        stop_btn.click(
            fn=stop_run,
            inputs=[pid],
            outputs=[run_btn, stop_btn],
            cancels=[run_event],
        )

    server_options = {"server_port": port}
    if listen:
        server_options["server_name"] = "0.0.0.0"
    demo.queue().launch(**server_options)
def watch(
    pid: int, trial_dir: str, alive_timeout: int, wait_timeout: int, check_interval: int
) -> None:
    """Kill process ``pid`` once its heartbeat file goes stale.

    First waits for the ``alive`` file to appear inside ``trial_dir``; if it
    does not appear within ``wait_timeout`` seconds the watcher exits with
    status 1. Afterwards the timestamp stored in that file is checked every
    ``check_interval`` seconds, and the process is killed when the timestamp
    is older than ``alive_timeout`` seconds. This function does not return
    normally; it always ends by raising ``SystemExit``.

    Args:
        pid: Id of the process to watch.
        trial_dir: Directory expected to contain the ``alive`` file.
        alive_timeout: Maximum heartbeat age in seconds before killing.
        wait_timeout: Seconds to wait for the ``alive`` file to appear.
        check_interval: Seconds to sleep between checks.
    """
    print(f"Spawn watcher for process {pid}")

    def timeout_handler(signum, frame):
        # The alive file never showed up in time. Raise SystemExit directly
        # instead of calling the interactive-only ``exit`` helper.
        raise SystemExit(1)

    alive_path = os.path.join(trial_dir, "alive")

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(wait_timeout)

    def loop_find_alive_file():
        while not os.path.exists(alive_path):
            time.sleep(check_interval)
        # found it: cancel the pending alarm
        signal.alarm(0)

    def loop_check_alive():
        # Last heartbeat successfully read; starts at "now" so an unreadable
        # file still leads to a timeout instead of an endless wait.
        last_alive = time.time()
        while True:
            if not psutil.pid_exists(pid):
                print(f"Process {pid} not exists, watcher exits.")
                raise SystemExit(0)
            try:
                with open(alive_path) as alive_fp:
                    last_alive = float(alive_fp.read())
            except (OSError, ValueError):
                # The writer truncates the file before rewriting it, so a
                # read can observe an empty file. Keep the previous value
                # and try again on the next tick.
                pass
            if time.time() - last_alive > alive_timeout:
                print(f"Alive timeout for process {pid}, killed.")
                try:
                    os.kill(pid, signal.SIGKILL)
                except Exception:
                    # best effort, e.g. the process exited in the meantime
                    print(f"Exception when killing process {pid}.")
                raise SystemExit(0)
            time.sleep(check_interval)

    # loop until alive file is found, or wait_timeout is reached
    loop_find_alive_file()
    # kill the process if it is not accessed for alive_timeout seconds
    loop_check_alive()
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("operation", type=str, choices=["launch", "watch"])

    # First pass: only find out which operation was requested; the
    # operation-specific options are registered afterwards and parsed again.
    first_pass, _unparsed = cli.parse_known_args()
    operation = first_pass.operation

    if operation == "launch":
        cli.add_argument("--listen", action="store_true")
        cli.add_argument("--port", type=int, default=7860)
        options = cli.parse_args()
        launch(options.port, listen=options.listen)
    elif operation == "watch":
        cli.add_argument("--pid", type=int)
        cli.add_argument("--trial-dir", type=str)
        cli.add_argument("--alive-timeout", type=int, default=10)
        cli.add_argument("--wait-timeout", type=int, default=10)
        cli.add_argument("--check-interval", type=int, default=1)
        options = cli.parse_args()
        watch(
            options.pid,
            options.trial_dir,
            alive_timeout=options.alive_timeout,
            wait_timeout=options.wait_timeout,
            check_interval=options.check_interval,
        )