DreamCraft3D/threestudio/scripts/generate_images_if.py
2023-12-15 17:44:44 +08:00

92 lines
3.2 KiB
Python

from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
import torch
import os
import glob
import json
import argparse
import numpy as np
from tqdm import tqdm
SAVE_FOLDER = "./load/images_dreamfusion"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--rank", default=0, type=int, help="# of GPU")
parser.add_argument("--prompt",required=True, type=str)
args = parser.parse_args()
# stage 1
stage_1 = DiffusionPipeline.from_pretrained(
"DeepFloyd/IF-I-XL-v1.0",
variant="fp16",
torch_dtype=torch.float16,
local_files_only=True
)
stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_1.enable_model_cpu_offload()
# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0",
text_encoder=None,
variant="fp16",
torch_dtype=torch.float16,
local_files_only=True
)
# stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_2.enable_model_cpu_offload()
# stage 3
# safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker}
safety_modules = None
stage_3 = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler",
torch_dtype=torch.float16,
local_files_only=True
)
stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()
# # load prompt library
# with open(os.path.join("load/prompt_library.json"), "r") as f:
# prompt_library = json.load(f)
# n_prompts = len(prompt_library["dreamfusion"])
# n_prompts_per_rank = int(np.ceil(n_prompts / 8))
# for prompt in tqdm(prompt_library["dreamfusion"][args.rank * n_prompts_per_rank : (args.rank + 1) * n_prompts_per_rank]):
prompt = args.prompt
print("Prompt:", prompt)
save_folder = os.path.join(SAVE_FOLDER, prompt)
os.makedirs(save_folder, exist_ok=True)
# if len(glob.glob(f"{save_folder}/*.png")) >= 30:
# continue
# enhance prompt
prompt = prompt + ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, hyperrealistic, intricate details, ultra-realistic, award-winning"
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
for _ in tqdm(range(30)):
seed = np.random.randint(low=0, high=10000000, size=1)[0]
generator = torch.manual_seed(seed)
### Stage 1
image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images
# pt_to_pil(image)[0].save("./if_stage_I.png")
### Stage 2
image = stage_2(
image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
).images
# pt_to_pil(image)[0].save("./if_stage_II.png")
### Stage 3
image = stage_3(prompt=prompt, image=(image.float() * 0.5 + 0.5), generator=generator, noise_level=100).images
image[0].save(f"{save_folder}/img_{seed:08d}.png")