# DreamCraft3D/threestudio/scripts/dreamcraft3d_dreambooth.py
# Last modified: 2023-12-15 17:44:44 +08:00
#
# 54 lines, 1.8 KiB, Python

import argparse
import os
from subprocess import run, CalledProcessError
import cv2
import glob
import numpy as np
import pytorch_lightning as pl
import torch
from tqdm import tqdm
from torchvision.utils import save_image
from threestudio.scripts.generate_mv_datasets import generate_mv_dataset
from threestudio.utils.config import load_config
import threestudio
def parse_args(argv=None):
    """Parse the command line for the DreamCraft3D DreamBooth driver.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        A ``(args, extras)`` tuple. ``args`` carries ``config`` (required
        path to the config file) and ``action`` (one of ``"gen_data"``,
        ``"dreambooth"``, ``"both"``; defaults to ``"both"``). ``extras`` is
        the list of tokens argparse did not recognize; they are returned
        rather than rejected so the caller can forward them onward.

    Raises:
        SystemExit: Raised by argparse when ``--config`` is missing or
            ``--action`` is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--action",
        default="both",
        choices=["gen_data", "dreambooth", "both"],
        help="action to perform",
    )
    # parse_known_args (not parse_args) so unknown tokens are collected into
    # `extras` instead of causing an error.
    args, extras = parser.parse_known_args(argv)
    return args, extras
def _dreambooth_command(db_cfg):
    """Return the argv list that launches DreamBooth training via ``accelerate``.

    ``db_cfg`` is the ``custom_import.dreambooth`` section of the loaded
    config; it must expose ``model_name``, ``instance_dir``, ``output_dir``
    and ``prompt_dreambooth``. Building a list (instead of a shell string)
    means values containing spaces, quotes or shell metacharacters are passed
    through verbatim as single arguments.
    """
    return [
        "accelerate",
        "launch",
        # Relative path: assumes the script is run from the repository root,
        # as the original shell command did.
        "threestudio/scripts/train_dreambooth.py",
        f"--pretrained_model_name_or_path={db_cfg.model_name}",
        f"--instance_data_dir={db_cfg.instance_dir}",
        f"--output_dir={db_cfg.output_dir}",
        f"--instance_prompt={db_cfg.prompt_dreambooth}",
        "--resolution=512",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=1",
        "--learning_rate=1e-6",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--max_train_steps=1000",
    ]


def main(args, extras):
    """Generate the multi-view dataset and/or run DreamBooth fine-tuning.

    Args:
        args: Parsed arguments with ``config`` (path to the config file) and
            ``action`` (``"gen_data"``, ``"dreambooth"`` or ``"both"``).
        extras: Unrecognized CLI tokens, passed to ``load_config`` as
            ``cli_args`` (presumably config overrides — confirm against
            ``load_config``).

    Raises:
        SystemExit: If the DreamBooth training subprocess exits with a
            non-zero status; that status becomes this process's exit code.
    """
    cfg = load_config(args.config, cli_args=extras, n_gpus=1)

    if args.action in ("gen_data", "both"):
        # Generate multi-view dataset.
        generate_mv_dataset(cfg)

    if args.action in ("dreambooth", "both"):
        # Run DreamBooth in a child process without a shell. check=True turns
        # a failed training run into an error instead of silently ignoring it
        # (os.system's return value was previously discarded).
        command = _dreambooth_command(cfg.custom_import.dreambooth)
        try:
            run(command, check=True)
        except CalledProcessError as err:
            raise SystemExit(err.returncode) from err
if __name__ == "__main__":
    # Script entry point: parse the CLI, then hand (args, extras) to main.
    main(*parse_args())