mirror of
https://github.com/deepseek-ai/DreamCraft3D
synced 2024-12-05 02:25:45 +00:00
156 lines
4.8 KiB
Python
156 lines
4.8 KiB
Python
import gc
|
|
import os
|
|
import re
|
|
|
|
import tinycudann as tcnn
|
|
import torch
|
|
from packaging import version
|
|
|
|
from threestudio.utils.config import config_to_primitive
|
|
from threestudio.utils.typing import *
|
|
|
|
|
|
def parse_version(ver: str):
    """Parse ``ver`` with ``packaging.version.parse`` and return the result."""
    parsed = version.parse(ver)
    return parsed
|
|
|
|
|
|
def get_rank():
    """Return this process's rank as an int, read from the environment.

    The variables are consulted in a fixed priority order and the first one
    that is set wins; when none of them is set the rank is 0.
    """
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    env = os.environ
    names = ("LOCAL_RANK", "RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    found = next((env[name] for name in names if name in env), None)
    return 0 if found is None else int(found)
|
|
|
|
|
|
def get_device():
    """Return the ``torch.device`` ``cuda:<rank>``, with rank from ``get_rank()``."""
    device_name = "cuda:{}".format(get_rank())
    return torch.device(device_name)
|
|
|
|
|
|
def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> Tuple[dict, int, int]:
    """Load (a subset of) the ``state_dict`` stored in a checkpoint file.

    Args:
        path: Checkpoint location, passed to ``torch.load``.
        module_name: If given, keep only the entries whose key starts with
            ``"<module_name>."`` and strip that prefix from the returned keys.
            The name is matched literally (dots are not wildcards).
        ignore_modules: If given, drop every entry whose key starts with
            ``"<name>."`` for any ``name`` in this iterable.
        map_location: Forwarded to ``torch.load``; defaults to ``get_device()``.

    Returns:
        ``(state_dict, epoch, global_step)``; the last two are read from the
        checkpoint's ``"epoch"`` and ``"global_step"`` entries.

    Raises:
        ValueError: If both ``module_name`` and ``ignore_modules`` are set.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]
    state_dict_to_load = state_dict

    if ignore_modules is not None:
        # str.startswith accepts a tuple, so one call tests every prefix.
        ignored_prefixes = tuple(name + "." for name in ignore_modules)
        state_dict_to_load = {
            k: v for k, v in state_dict.items() if not k.startswith(ignored_prefixes)
        }

    if module_name is not None:
        # Plain prefix matching: interpolating module_name into a regex would
        # treat "." (and other metacharacters) in the name as pattern syntax.
        prefix = module_name + "."
        state_dict_to_load = {
            k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
        }

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
|
|
|
|
|
|
def C(value: Any, epoch: int, global_step: int) -> float:
    """Evaluate a possibly scheduled scalar at the current training position.

    ``value`` is either a plain ``int``/``float``, which is returned as is, or
    a specification that ``config_to_primitive`` turns into a list:

    * ``[start_step, start_value, end_value, end_step]`` - linear ramp from
      ``start_value`` to ``end_value``, clamped outside the step range.
    * ``[start_value, end_value, end_step]`` - same, with ``start_step = 0``.
    * six or more entries ``[s0, v0, v1, s1, v2, s2, ...]`` - a chain of ramps;
      the active one is picked by comparing ``global_step`` with the
      intermediate step entries.

    An ``int`` ``end_step`` means the ramp is driven by ``global_step``; a
    ``float`` ``end_step`` means it is driven by ``epoch``.  For any other
    ``end_step`` type the 4-element list itself is returned.

    Raises:
        TypeError: If the converted specification is not a list.
    """
    # Plain numbers are constants: nothing to schedule.
    if isinstance(value, (int, float)):
        return value

    schedule = config_to_primitive(value)
    if not isinstance(schedule, list):
        raise TypeError("Scalar specification only supports list, got", type(schedule))

    if len(schedule) == 3:
        # Short form omits the start step, which is taken to be 0.
        schedule = [0] + schedule

    if len(schedule) >= 6:
        # Walk the intermediate step entries (indices 3, 5, ...) and remember
        # the end index of the last ramp whose start has been reached.
        chosen = 3
        for idx in range(3, len(schedule) - 2, 2):
            if global_step >= schedule[idx]:
                chosen = idx + 2
        if chosen == 3:
            # First ramp: its start is stored as (step, value).
            begin_step, begin_value = schedule[:2]
        else:
            # Later ramps start where the previous one ended: (value, step).
            begin_value, begin_step = schedule[chosen - 3], schedule[chosen - 2]
        schedule = [begin_step, begin_value, schedule[chosen - 1], schedule[chosen]]

    assert len(schedule) == 4
    begin_step, begin_value, final_value, final_step = schedule

    if isinstance(final_step, int):
        position = global_step
    elif isinstance(final_step, float):
        position = epoch
    else:
        return schedule

    fraction = (position - begin_step) / (final_step - begin_step)
    return begin_value + (final_value - begin_value) * max(min(1.0, fraction), 0.0)
|
|
|
|
|
|
def cleanup():
    """Free memory held by Python, the PyTorch CUDA cache and tiny-cuda-nn.

    Runs the Python garbage collector, then ``torch.cuda.empty_cache()`` and
    finally ``tcnn.free_temporary_memory()``.
    """
    # The collector runs before the caches are emptied (presumably so blocks
    # it releases can be returned by the calls that follow).
    gc.collect()
    torch.cuda.empty_cache()
    tcnn.free_temporary_memory()
|
|
|
|
|
|
def finish_with_cleanup(func: Callable):
    """Decorator that calls ``cleanup()`` after ``func`` returns.

    The wrapper forwards all positional and keyword arguments to ``func``,
    runs ``cleanup()`` once ``func`` has returned, and hands back ``func``'s
    result.  If ``func`` raises, the exception propagates and ``cleanup()`` is
    not invoked.

    The wrapper keeps ``func``'s ``__name__``, ``__doc__`` and related
    metadata via ``functools.wraps``.
    """
    import functools  # local import: keeps this block self-contained

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper
|
|
|
|
|
|
def _distributed_available():
    """Tell whether ``torch.distributed`` is both available and initialized."""
    dist = torch.distributed
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
|
|
|
|
|
def barrier():
    """Call ``torch.distributed.barrier()``; do nothing when not distributed."""
    if _distributed_available():
        torch.distributed.barrier()
|
|
|
|
|
|
def broadcast(tensor, src=0):
    """Broadcast ``tensor`` from rank ``src`` and return it.

    When ``torch.distributed`` is unavailable or uninitialized the tensor is
    returned untouched; otherwise it is passed to
    ``torch.distributed.broadcast`` and the same object is returned.
    """
    if _distributed_available():
        torch.distributed.broadcast(tensor, src=src)
    return tensor
|
|
|
|
|
|
def enable_gradient(model, enabled: bool = True) -> None:
    """Set ``requires_grad`` to ``enabled`` on every parameter of ``model``."""
    for weight in model.parameters():
        weight.requires_grad_(requires_grad=enabled)
|
|
|
|
def find_last_path(path: str):
    """Resolve a ``LAST`` placeholder in ``path`` to an existing path.

    If ``path`` is ``None`` or does not contain ``"LAST"`` it is returned
    unchanged.  Otherwise spaces in ``path`` are replaced by underscores, the
    path is split at the first ``"LAST"``, and the directory holding the
    placeholder is listed.  Among the entries whose path starts with the text
    preceding ``"LAST"``, the lexicographically greatest one is chosen and the
    text following ``"LAST"`` is appended to it.

    Returns:
        The resolved path, or ``path`` itself when there is no placeholder.

    Raises:
        FileNotFoundError: If no directory entry matches the prefix, or the
            resolved path does not exist.
    """
    if path is None or "LAST" not in path:
        return path

    path = path.replace(" ", "_")
    before, suffix = path.split("LAST", 1)
    base_dir, name_prefix = os.path.split(before)
    base_dir_prefix = os.path.join(base_dir, name_prefix)

    # An empty directory part means the current directory; os.listdir("")
    # would fail, while joining with "" still yields bare entry names.
    entries = os.listdir(base_dir or ".")
    candidates = [
        candidate
        for candidate in (os.path.join(base_dir, entry) for entry in entries)
        if candidate.startswith(base_dir_prefix)
    ]
    if not candidates:
        raise FileNotFoundError(
            f"no entry starting with '{base_dir_prefix}' to substitute for LAST"
        )

    # Greatest name in string order, same pick as a descending sort's first item.
    new_path = max(candidates) + suffix
    if not os.path.exists(new_path):
        raise FileNotFoundError(new_path)
    return new_path