mirror of
https://github.com/deepseek-ai/DreamCraft3D
synced 2024-12-05 02:25:45 +00:00
54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
import argparse
|
|
import os
|
|
from subprocess import run, CalledProcessError
|
|
|
|
import cv2
|
|
import glob
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from tqdm import tqdm
|
|
from torchvision.utils import save_image
|
|
|
|
from threestudio.scripts.generate_mv_datasets import generate_mv_dataset
|
|
from threestudio.utils.config import load_config
|
|
import threestudio
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command line for the multi-view data / DreamBooth script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, matching the previous
            zero-argument behavior.

    Returns:
        A ``(args, extras)`` tuple. ``args`` is the namespace holding
        ``config`` (path to the config file) and ``action`` (one of
        ``"gen_data"``, ``"dreambooth"``, ``"both"``; default ``"both"``).
        ``extras`` is the list of unrecognized tokens, which ``main`` passes to
        ``load_config`` as ``cli_args``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--action",
        default="both",
        choices=["gen_data", "dreambooth", "both"],
        help="action to perform",
    )
    # parse_known_args (rather than parse_args) lets unknown tokens through
    # as config overrides instead of aborting with a usage error.
    args, extras = parser.parse_known_args(argv)
    return args, extras
|
|
|
|
|
|
def main(args, extras):
    """Run multi-view dataset generation and/or DreamBooth fine-tuning.

    Args:
        args: Namespace from ``parse_args``; ``args.config`` is the config
            path and ``args.action`` selects the stage(s): ``"gen_data"``,
            ``"dreambooth"`` or ``"both"``.
        extras: Unrecognized CLI tokens, forwarded to ``load_config`` as
            ``cli_args``.

    Raises:
        subprocess.CalledProcessError: If the DreamBooth training process
            exits with a non-zero status.
    """
    cfg = load_config(args.config, cli_args=extras, n_gpus=1)

    if args.action in ("gen_data", "both"):
        # Generate multi-view dataset.
        generate_mv_dataset(cfg)

    if args.action in ("dreambooth", "both"):
        # Run DreamBooth. The command is an argument list executed without a
        # shell, so config values (notably the free-text prompt) are passed
        # verbatim: quotes, spaces or `$` in them cannot break or inject into
        # the command line.
        db_cfg = cfg.custom_import.dreambooth
        command = [
            "accelerate",
            "launch",
            "threestudio/scripts/train_dreambooth.py",
            f"--pretrained_model_name_or_path={db_cfg.model_name}",
            f"--instance_data_dir={db_cfg.instance_dir}",
            f"--output_dir={db_cfg.output_dir}",
            f"--instance_prompt={db_cfg.prompt_dreambooth}",
            "--resolution=512",
            "--train_batch_size=2",
            "--gradient_accumulation_steps=1",
            "--learning_rate=1e-6",
            "--lr_scheduler=constant",
            "--lr_warmup_steps=0",
            "--max_train_steps=1000",
        ]
        # check=True raises CalledProcessError when training fails instead
        # of letting the script report success.
        run(command, check=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI, then hand (args, extras) to main().
    main(*parse_args())
|