mirror of
https://github.com/deepseek-ai/DreamCraft3D
synced 2024-12-04 18:15:11 +00:00
451 lines
14 KiB
Python
451 lines
14 KiB
Python
import argparse
|
|
import glob
|
|
import os
|
|
import re
|
|
import signal
|
|
import subprocess
|
|
import tempfile
|
|
import time
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
import gradio as gr
|
|
import numpy as np
|
|
import psutil
|
|
import trimesh
|
|
|
|
|
|
def tail(f, window=20):
    """Return the last ``window`` lines of the binary file object ``f``.

    The file is read backwards from its end in fixed-size blocks until more
    than ``window`` newlines have been seen (or the start of the file is
    reached), so only the end of a large file is touched.

    Args:
        f: Seekable file object opened in binary mode. Its position is moved
            by this call; closing it remains the caller's responsibility.
        window: Maximum number of trailing lines to return.

    Returns:
        The trailing lines joined with a newline character, as one ``str``.
        For backward compatibility ``window == 0`` returns an empty list.
    """
    if window == 0:
        return []

    BUFSIZ = 1024
    f.seek(0, os.SEEK_END)
    remaining_bytes = f.tell()
    # One extra newline is required so that the (possibly partial) first line
    # of the oldest block read is never part of the returned window.
    newlines_needed = window + 1
    blocks = []  # raw byte blocks, newest first

    while newlines_needed > 0 and remaining_bytes > 0:
        read_size = min(BUFSIZ, remaining_bytes)
        remaining_bytes -= read_size
        f.seek(remaining_bytes, os.SEEK_SET)
        block = f.read(read_size)
        blocks.append(block)
        newlines_needed -= block.count(b"\n")

    # Decode once over the joined bytes: decoding each block separately fails
    # when a multi-byte UTF-8 character straddles a block boundary.
    text = b"".join(reversed(blocks)).decode("utf-8", errors="replace")
    return "\n".join(text.splitlines()[-window:])
|
|
|
|
|
|
@dataclass
class ExperimentStatus:
    """Snapshot of one training run as shown in the web demo."""

    # pid of the spawned training process
    pid: Optional[int] = None
    # one-line progress text
    progress: str = ""
    # trailing lines of the training log
    log: str = ""
    # file paths of the most recent outputs, when available
    output_image: Optional[str] = None
    output_video: Optional[str] = None
    output_mesh: Optional[str] = None

    def tolist(self):
        """Return the field values as a list, in declaration order."""
        ordered_names = (
            "pid",
            "progress",
            "log",
            "output_image",
            "output_video",
            "output_mesh",
        )
        return [getattr(self, field_name) for field_name in ordered_names]
|
|
|
|
|
|
# Passed to the training process as `exp_root_dir`; trial directories live below it.
EXP_ROOT_DIR = "outputs-gradio"
# Initial content of the prompt textbox.
DEFAULT_PROMPT = "a delicious hamburger"
# (display name, config file path) pairs, in the order shown in the dropdown.
model_config = [
    ("DreamFusion (DeepFloyd-IF)", "configs/gradio/dreamfusion-if.yaml"),
    ("DreamFusion (Stable Diffusion)", "configs/gradio/dreamfusion-sd.yaml"),
    ("TextMesh (DeepFloyd-IF)", "configs/gradio/textmesh-if.yaml"),
    ("Fantasia3D (Stable Diffusion, Geometry Only)", "configs/gradio/fantasia3d.yaml"),
    ("SJC (Stable Diffusion)", "configs/gradio/sjc.yaml"),
    ("Latent-NeRF (Stable Diffusion)", "configs/gradio/latentnerf.yaml"),
]
model_choices = [display_name for display_name, _config_path in model_config]
model_name_to_config = dict(model_config)
|
|
|
|
|
|
def load_model_config(model_name):
    """Return the text of the config file registered for ``model_name``.

    Args:
        model_name: A key of ``model_name_to_config``.

    Returns:
        The full file content as a string.

    Raises:
        KeyError: If ``model_name`` is not a registered model.
        OSError: If the config file cannot be opened.
    """
    # Use a context manager so the handle is closed right away instead of
    # lingering until garbage collection.
    with open(model_name_to_config[model_name]) as config_fp:
        return config_fp.read()
|
|
|
|
|
|
def load_model_config_attrs(model_name):
    """Parse the config of ``model_name`` and return the values the UI needs.

    Returns:
        A dict with the raw config text (``"source"``), the configured
        guidance scale (``"guidance_scale"``) and the trainer step count
        (``"max_steps"``).
    """
    source_text = load_model_config(model_name)
    # Imported here rather than at module level, as in the original code.
    from threestudio.utils.config import load_config

    # Dummy name/tag and a placeholder prompt are passed as overrides;
    # presumably required for the config to resolve — confirm against load_config.
    overrides = [
        "name=dummy",
        "tag=dummy",
        "use_timestamp=false",
        f"exp_root_dir={EXP_ROOT_DIR}",
        "system.prompt_processor.prompt=placeholder",
    ]
    parsed = load_config(source_text, cli_args=overrides, from_string=True)

    attrs = {"source": source_text}
    attrs["guidance_scale"] = parsed.system.guidance.guidance_scale
    attrs["max_steps"] = parsed.trainer.max_steps
    return attrs
|
|
|
|
|
|
def on_model_selector_change(model_name):
    """Return ``[config text, guidance scale]`` for the newly selected model."""
    attrs = load_model_config_attrs(model_name)
    source_text, scale = attrs["source"], attrs["guidance_scale"]
    return [source_text, scale]
|
|
|
|
|
|
# Maps (source .obj path, its mtime) -> path of the converted temporary .obj,
# so a mesh is re-exported only when a new export appears rather than on
# every status poll.
_DISPLAY_MESH_CACHE = {}


def _write_alive_timestamp(alive_path):
    """Atomically store the current time in ``alive_path`` if its directory exists."""
    if not os.path.exists(os.path.dirname(alive_path)):
        return
    # Write to a sibling file and rename it into place so a concurrent reader
    # never observes a truncated (empty) alive file.
    pending_path = alive_path + ".tmp"
    with open(pending_path, "w") as alive_fp:
        alive_fp.write(str(time.time()))
    os.replace(pending_path, alive_path)


def _latest_by_step(paths, pattern):
    """Return the path with the highest step number captured by ``pattern``, or None.

    Paths whose basename does not match ``pattern`` are ignored.
    """
    best_path, best_step = None, -1
    for path in paths:
        matched = re.match(pattern, os.path.basename(path))
        if matched is None:
            continue
        step = int(matched.group(1))
        # `>=` keeps the last candidate among equal steps, like a stable sort
        # followed by taking the final element.
        if step >= best_step:
            best_path, best_step = path, step
    return best_path


def _export_display_mesh(obj_path):
    """Re-save ``obj_path`` through trimesh into a temp file and return its path."""
    cache_key = (obj_path, os.path.getmtime(obj_path))
    cached_path = _DISPLAY_MESH_CACHE.get(cache_key)
    if cached_path is not None and os.path.exists(cached_path):
        return cached_path

    # FIXME
    # seems the gr.Model3D cannot load our manually saved obj file
    # here we load the obj and save it to a temporary file using trimesh
    mesh_file = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
    mesh_file.close()  # only the reserved name is needed; trimesh writes the file
    trimesh.load(obj_path).export(mesh_file.name)
    _DISPLAY_MESH_CACHE[cache_key] = mesh_file.name
    return mesh_file.name


def get_current_status(process, trial_dir, alive_path):
    """Collect the current state of a training run from its trial directory.

    Also refreshes the alive file, which the watcher process reads to learn
    when this run was last polled.

    Args:
        process: Object exposing ``pid`` (``run`` passes a ``subprocess.Popen``).
        trial_dir: Trial directory containing ``logs``, ``progress`` and ``save``.
        alive_path: Path of the alive file to refresh.

    Returns:
        An ``ExperimentStatus`` with the pid, the progress text, the last 10
        log lines and the newest image / video / mesh paths found (``None``
        when absent).
    """
    status = ExperimentStatus()
    status.pid = process.pid

    # write the current timestamp to the alive file
    # the watcher will know the last active time of this process from this timestamp
    _write_alive_timestamp(alive_path)

    log_path = os.path.join(trial_dir, "logs")
    progress_path = os.path.join(trial_dir, "progress")
    save_path = os.path.join(trial_dir, "save")

    # read current progress from the progress file
    # the progress file is created by GradioCallback
    if os.path.exists(progress_path):
        with open(progress_path) as progress_fp:
            status.progress = progress_fp.read()
    else:
        status.progress = "Setting up everything ..."

    # read the last 10 lines of the log file
    if os.path.exists(log_path):
        with open(log_path, "rb") as log_fp:
            status.log = tail(log_fp, window=10)
    else:
        status.log = ""

    # get the validation image, testing video and exported mesh if they exist
    if os.path.exists(save_path):
        status.output_image = _latest_by_step(
            glob.glob(os.path.join(save_path, "*.png")), r"it(\d+)-0\.png"
        )
        status.output_video = _latest_by_step(
            glob.glob(os.path.join(save_path, "*.mp4")), r"it(\d+)-test\.mp4"
        )
        latest_export_dir = _latest_by_step(
            glob.glob(os.path.join(save_path, "*export")), r"it(\d+)-export"
        )
        if latest_export_dir is not None:
            obj_files = glob.glob(os.path.join(latest_export_dir, "*.obj"))
            if obj_files:
                status.output_mesh = _export_display_mesh(obj_files[0])

    return status
|
|
|
|
|
|
def run(
    model_name: str,
    config: str,
    prompt: str,
    guidance_scale: float,
    seed: int,
    max_steps: int,
):
    """Launch a training process plus its watcher and stream status updates.

    This is a generator. While the training process is alive it yields, once
    per second, the six ``ExperimentStatus`` values followed by two button
    updates (Run hidden, Stop visible). After the process exits it yields one
    final status with the buttons switched back.

    Args:
        model_name: Key of ``model_name_to_config``; selects the experiment name.
        config: Full config text, written to a temporary file for ``launch.py``.
        prompt: Text prompt passed as ``system.prompt_processor.prompt``.
        guidance_scale: Override for ``system.guidance.guidance_scale``.
        seed: Override for ``seed``.
        max_steps: Override for ``trainer.max_steps``.
    """
    # update status every 1 second
    status_update_interval = 1

    # manually assign the output directory, name and tag so that we know the trial directory
    name = os.path.basename(model_name_to_config[model_name]).split(".")[0]
    tag = datetime.now().strftime("@%Y%m%d-%H%M%S")
    trial_dir = os.path.join(EXP_ROOT_DIR, name, tag)
    alive_path = os.path.join(trial_dir, "alive")

    # The prompt is embedded in a double-quoted override value; escape
    # backslashes and double quotes so they cannot terminate or corrupt it.
    escaped_prompt = prompt.replace("\\", "\\\\").replace('"', '\\"')

    # save the config to a temporary file; the file is removed when the
    # `with` block is left (including when this generator is cancelled)
    with tempfile.NamedTemporaryFile(mode="w") as config_file:
        config_file.write(config)
        config_file.flush()

        # spawn the training process
        # the argv list is built explicitly (no str.split) so a temp path
        # containing whitespace stays a single argument
        process = subprocess.Popen(
            [
                "python",
                "launch.py",
                "--config",
                config_file.name,
                "--train",
                "--gpu",
                "0",
                "--gradio",
                "trainer.enable_progress_bar=false",
                f'name="{name}"',
                f'tag="{tag}"',
                f"exp_root_dir={EXP_ROOT_DIR}",
                "use_timestamp=false",
                f'system.prompt_processor.prompt="{escaped_prompt}"',
                f"system.guidance.guidance_scale={guidance_scale}",
                f"seed={seed}",
                f"trainer.max_steps={max_steps}",
            ]
        )

        # spawn the watcher process
        watch_process = subprocess.Popen(
            [
                "python",
                "gradio_app.py",
                "watch",
                "--pid",
                f"{process.pid}",
                "--trial-dir",
                f"{trial_dir}",
            ]
        )

        # update status (progress, log, image, video) every status_update_interval seconds
        # button status: Run -> Stop
        while process.poll() is None:
            time.sleep(status_update_interval)
            yield get_current_status(process, trial_dir, alive_path).tolist() + [
                gr.update(visible=False),
                gr.update(value="Stop", variant="stop", visible=True),
            ]

        # wait for the processes to finish
        process.wait()
        watch_process.wait()

    # update status one last time
    # button status: Stop / Reset -> Run
    status = get_current_status(process, trial_dir, alive_path)
    status.progress = "Finished."
    yield status.tolist() + [
        gr.update(value="Run", variant="primary", visible=True),
        gr.update(visible=False),
    ]
|
|
|
|
|
|
def stop_run(pid):
    """Send SIGKILL to process ``pid`` and return the button updates.

    Killing is best-effort: a failure is reported on stdout and otherwise
    ignored.

    Args:
        pid: Process id of the training process, or ``None`` if no run has
            reported a pid yet.

    Returns:
        Two ``gr.update`` objects: the first button shown as "Reset", the
        second hidden.
    """
    # kill the process
    print(f"Trying to kill process {pid} ...")
    try:
        os.kill(pid, signal.SIGKILL)
    except (OSError, TypeError):
        # OSError: the process is already gone or may not be signalled.
        # TypeError: pid is None because no run has reported a pid yet.
        print(f"Exception when killing process {pid}.")
    # button status: Stop -> Reset
    return [
        gr.update(value="Reset", variant="secondary", visible=True),
        gr.update(visible=False),
    ]
|
|
|
|
|
|
def launch(port, listen=False):
    """Build the web demo UI and serve it.

    Args:
        port: TCP port passed to gradio as ``server_port``.
        listen: When true, bind to ``0.0.0.0`` instead of gradio's default host.
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from statement order — confirm against the upstream file.
    with gr.Blocks(title="threestudio - Web Demo") as demo:
        with gr.Row():
            pid = gr.State()
            with gr.Column(scale=1):
                gr.Markdown(
                    """
# threestudio

- Select a model from the dropdown menu.
- Input a text prompt.
- Hit Run!
"""
                )

                # model selection dropdown
                model_selector = gr.Dropdown(
                    value=model_choices[0],
                    choices=model_choices,
                    label="Select a model",
                )

                # Parse the default model's config once; it supplies both the
                # initial guidance scale and the text of the config viewer.
                initial_attrs = load_model_config_attrs(model_selector.value)

                # prompt input
                prompt_input = gr.Textbox(value=DEFAULT_PROMPT, label="Input prompt")

                # guidance scale slider
                guidance_scale_input = gr.Slider(
                    minimum=0.0,
                    maximum=100.0,
                    value=initial_attrs["guidance_scale"],
                    step=0.5,
                    label="Guidance scale",
                )

                # seed slider
                seed_input = gr.Slider(
                    minimum=0, maximum=2147483647, value=0, step=1, label="Seed"
                )

                # number of training steps slider
                max_steps_input = gr.Slider(
                    minimum=1,
                    maximum=5000,
                    value=5000,
                    step=1,
                    label="Number of training steps",
                )

                # full config viewer
                with gr.Accordion("See full configurations", open=False):
                    config_editor = gr.Code(
                        value=initial_attrs["source"],
                        language="yaml",
                        interactive=False,
                    )

                # load config on model selection change
                model_selector.change(
                    fn=on_model_selector_change,
                    inputs=model_selector,
                    outputs=[config_editor, guidance_scale_input],
                )

                run_btn = gr.Button(value="Run", variant="primary")
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)

                # generation status
                status = gr.Textbox(
                    value="Hit the Run button to start.",
                    label="Status",
                    lines=1,
                    max_lines=1,
                )

            with gr.Column(scale=1):
                with gr.Accordion("See terminal logs", open=False):
                    # logs
                    logs = gr.Textbox(label="Logs", lines=10)

                # validation image display
                output_image = gr.Image(value=None, label="Image")

                # testing video display
                output_video = gr.Video(value=None, label="Video")

                # export mesh display
                output_mesh = gr.Model3D(value=None, label="3D Mesh")

        # The output order must match the lists yielded by `run`: the six
        # status values followed by the two button updates.
        run_event = run_btn.click(
            fn=run,
            inputs=[
                model_selector,
                config_editor,
                prompt_input,
                guidance_scale_input,
                seed_input,
                max_steps_input,
            ],
            outputs=[
                pid,
                status,
                logs,
                output_image,
                output_video,
                output_mesh,
                run_btn,
                stop_btn,
            ],
        )
        # Stop kills the process and cancels the running `run` generator.
        stop_btn.click(
            fn=stop_run, inputs=[pid], outputs=[run_btn, stop_btn], cancels=[run_event]
        )

    launch_args = {"server_port": port}
    if listen:
        launch_args["server_name"] = "0.0.0.0"
    demo.queue().launch(**launch_args)
|
|
|
|
|
|
def watch(
    pid: int, trial_dir: str, alive_timeout: int, wait_timeout: int, check_interval: int
) -> None:
    """Kill process ``pid`` once its alive file stops being refreshed.

    The watcher first waits up to ``wait_timeout`` seconds for
    ``<trial_dir>/alive`` to appear and exits with status 1 if it never does
    (the watched process is left untouched in that case). It then polls every
    ``check_interval`` seconds and exits with status 0 when the process no
    longer exists, or after sending it SIGKILL because the timestamp stored
    in the alive file is older than ``alive_timeout`` seconds.

    Args:
        pid: Process id to watch.
        trial_dir: Trial directory that contains the alive file.
        alive_timeout: Maximum age in seconds of the alive timestamp.
        wait_timeout: Seconds to wait for the alive file to appear.
        check_interval: Seconds to sleep between checks.

    Raises:
        SystemExit: Always; this function terminates the interpreter.
    """
    print(f"Spawn watcher for process {pid}")

    def timeout_handler(signum, frame):
        raise SystemExit(1)

    alive_path = os.path.join(trial_dir, "alive")
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(wait_timeout)

    def wait_for_alive_file():
        while not os.path.exists(alive_path):
            time.sleep(check_interval)
        # the file showed up in time: cancel the pending alarm
        signal.alarm(0)

    def read_alive_timestamp(fallback):
        # The writer may be rewriting the file at this very moment; an empty
        # or unreadable file keeps the last known timestamp instead of
        # crashing the watcher.
        try:
            with open(alive_path) as alive_fp:
                return float(alive_fp.read())
        except (OSError, ValueError):
            return fallback

    def loop_check_alive():
        last_alive = time.time()
        while True:
            if not psutil.pid_exists(pid):
                print(f"Process {pid} not exists, watcher exits.")
                raise SystemExit(0)
            last_alive = read_alive_timestamp(last_alive)
            if time.time() - last_alive > alive_timeout:
                print(f"Alive timeout for process {pid}, killed.")
                try:
                    os.kill(pid, signal.SIGKILL)
                except OSError:
                    print(f"Exception when killing process {pid}.")
                raise SystemExit(0)
            time.sleep(check_interval)

    # loop until alive file is found, or wait_timeout is reached
    wait_for_alive_file()
    # kill the process if it is not accessed for alive_timeout seconds
    loop_check_alive()
|
|
|
|
|
|
if __name__ == "__main__":
    # First pass: read only the operation and ignore everything else, then
    # register the options of that operation and parse the full command line.
    cli = argparse.ArgumentParser()
    cli.add_argument("operation", type=str, choices=["launch", "watch"])
    first_pass, _remaining = cli.parse_known_args()
    operation = first_pass.operation

    if operation == "launch":
        cli.add_argument("--listen", action="store_true")
        cli.add_argument("--port", type=int, default=7860)
        options = cli.parse_args()
        launch(options.port, listen=options.listen)
    elif operation == "watch":
        cli.add_argument("--pid", type=int)
        cli.add_argument("--trial-dir", type=str)
        cli.add_argument("--alive-timeout", type=int, default=10)
        cli.add_argument("--wait-timeout", type=int, default=10)
        cli.add_argument("--check-interval", type=int, default=1)
        options = cli.parse_args()
        watch(
            options.pid,
            options.trial_dir,
            alive_timeout=options.alive_timeout,
            wait_timeout=options.wait_timeout,
            check_interval=options.check_interval,
        )
|