mirror of
https://github.com/deepseek-ai/DreamCraft3D
synced 2024-12-04 18:15:11 +00:00
156 lines
5.0 KiB
Python
156 lines
5.0 KiB
Python
import os
|
|
import shutil
|
|
import subprocess
|
|
|
|
import pytorch_lightning
|
|
|
|
from threestudio.utils.config import dump_config
|
|
from threestudio.utils.misc import parse_version
|
|
|
|
if parse_version(pytorch_lightning.__version__) > parse_version("1.8"):
|
|
from pytorch_lightning.callbacks import Callback
|
|
else:
|
|
from pytorch_lightning.callbacks.base import Callback
|
|
|
|
from pytorch_lightning.callbacks.progress import TQDMProgressBar
|
|
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
|
|
|
|
|
class VersionedCallback(Callback):
    """Base callback that resolves an (optionally versioned) output directory.

    Args:
        save_root: Directory that holds the ``version_<n>`` sub-directories.
        version: Explicit version to use. When ``None``, the next free
            integer version is computed the first time ``version`` is read.
        use_version: When false, ``savedir`` is ``save_root`` itself.
    """

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__()
        self.save_root = save_root
        self._version = version
        self.use_version = use_version

    @property
    def version(self) -> int:
        """Get the experiment version.

        The value is computed lazily and cached, so repeated reads return the
        same version even after new directories appear under ``save_root``.
        A version passed explicitly to the constructor (including a string)
        is returned unchanged.

        Returns:
            The experiment version if specified else the next version.
        """
        if self._version is None:
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self):
        """Return one more than the highest ``version_<n>`` in ``save_root``.

        Returns 0 when ``save_root`` is not a directory or contains no
        integer-numbered ``version_*`` entries.
        """
        if not os.path.isdir(self.save_root):
            return 0

        existing_versions = []
        for entry in os.listdir(self.save_root):
            bn = os.path.basename(entry)
            if not bn.startswith("version_"):
                continue
            dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "")
            try:
                existing_versions.append(int(dir_ver))
            except ValueError:
                # Entries such as "version_final" or "version_" are not
                # auto-numbered runs; ignore them instead of crashing.
                continue

        return max(existing_versions, default=-1) + 1

    @property
    def savedir(self):
        """Directory this callback writes into.

        Returns ``save_root`` when versioning is disabled. Otherwise returns
        ``save_root/<version>`` for a string version, or
        ``save_root/version_<version>`` for any other version value.
        """
        if not self.use_version:
            return self.save_root
        version = self.version
        subdir = version if isinstance(version, str) else f"version_{version}"
        return os.path.join(self.save_root, subdir)
|
|
|
|
|
|
class CodeSnapshotCallback(VersionedCallback):
    """Copy the files reported by ``git ls-files`` into ``savedir`` at fit start."""

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)

    def get_file_list(self):
        """Return the de-duplicated paths reported by two ``git ls-files`` calls.

        The first call lists tracked files with the ``load/*`` pathspec
        excluded; the second lists untracked files that are not ignored
        (``--others --exclude-standard``). Order is unspecified because the
        two listings are merged through a set.

        Raises:
            subprocess.CalledProcessError: If either git command exits with a
                non-zero status.
        """
        # hard code, TODO: use config to exclude folders or files
        tracked = subprocess.check_output(
            'git ls-files -- ":!:load/*"', shell=True
        ).splitlines()
        untracked = subprocess.check_output(
            "git ls-files --others --exclude-standard", shell=True
        ).splitlines()
        return [b.decode() for b in set(tracked) | set(untracked)]

    @rank_zero_only
    def save_code_snapshot(self):
        """Copy every listed regular path into ``savedir``, keeping relative layout."""
        os.makedirs(self.savedir, exist_ok=True)
        for f in self.get_file_list():
            # Listed paths may be missing from the working tree or be
            # directories; neither can be copied as a file.
            if not os.path.exists(f) or os.path.isdir(f):
                continue
            os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
            shutil.copyfile(f, os.path.join(self.savedir, f))

    def on_fit_start(self, trainer, pl_module):
        """Take the snapshot on a best-effort basis; failures only warn."""
        try:
            self.save_code_snapshot()
        except Exception:
            # Deliberately broad so a failed snapshot never aborts training,
            # but not a bare ``except`` so KeyboardInterrupt/SystemExit still
            # propagate.
            rank_zero_warn(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
            )
|
|
|
|
|
|
class ConfigSnapshotCallback(VersionedCallback):
    """Store the parsed config and a copy of the raw config file in ``savedir``.

    Args:
        config_path: Path of the config file; copied to ``raw.yaml``.
        config: Config object handed to ``dump_config`` for ``parsed.yaml``.
        save_root, version, use_version: Forwarded to ``VersionedCallback``.
    """

    def __init__(self, config_path, config, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
        self.config = config
        self.config_path = config_path

    @rank_zero_only
    def save_config_snapshot(self):
        """Write ``parsed.yaml`` and ``raw.yaml`` into the save directory."""
        target_dir = self.savedir
        os.makedirs(target_dir, exist_ok=True)

        parsed_path = os.path.join(target_dir, "parsed.yaml")
        dump_config(parsed_path, self.config)

        raw_path = os.path.join(target_dir, "raw.yaml")
        shutil.copyfile(self.config_path, raw_path)

    def on_fit_start(self, trainer, pl_module):
        """Snapshot the configuration as soon as fitting begins."""
        self.save_config_snapshot()
|
|
|
|
|
|
class CustomProgressBar(TQDMProgressBar):
    """Progress bar that omits the ``v_num`` entry from the displayed metrics."""

    def get_metrics(self, *args, **kwargs):
        """Return the parent's metrics with ``v_num`` removed when present."""
        metrics = super().get_metrics(*args, **kwargs)
        # don't show the version number
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
|
|
|
|
|
|
class ProgressCallback(Callback):
    """Keep a one-line, human-readable status message in a text file.

    Each update overwrites the previous content of ``save_path``, so the file
    always holds only the most recent message.

    Args:
        save_path: Path of the status file. It is opened lazily on first write.
    """

    def __init__(self, save_path):
        super().__init__()
        self.save_path = save_path
        self._file_handle = None

    @property
    def file_handle(self):
        """File object for ``save_path``, opened in ``"w"`` mode on first access."""
        if self._file_handle is None:
            self._file_handle = open(self.save_path, "w")
        return self._file_handle

    @rank_zero_only
    def write(self, msg: str) -> None:
        """Replace the file's content with ``msg`` and flush it to disk."""
        self.file_handle.seek(0)
        self.file_handle.truncate()
        self.file_handle.write(msg)
        self.file_handle.flush()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        """Report training progress as a percentage of ``trainer.max_steps``."""
        # NOTE(review): assumes trainer.max_steps is a positive step budget;
        # 0 would raise ZeroDivisionError and a negative value would yield a
        # negative percentage - confirm against how the trainer is configured.
        self.write(
            f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%"
        )

    @rank_zero_only
    def on_validation_start(self, trainer, pl_module):
        """Report that validation has started."""
        self.write("Rendering validation image ...")

    @rank_zero_only
    def on_test_start(self, trainer, pl_module):
        """Report that testing has started."""
        self.write("Rendering video ...")

    @rank_zero_only
    def on_predict_start(self, trainer, pl_module):
        """Report that prediction has started."""
        self.write("Exporting mesh assets ...")

    def teardown(self, trainer, pl_module, stage=None):
        """Close the status file if it was opened.

        Lightning's Callback API invokes ``teardown`` when a stage ends.
        Resetting the handle to ``None`` lets a later write reopen the file
        through ``file_handle``; since every write truncates first, the
        reopened file behaves the same as the kept-open one.
        """
        if self._file_handle is not None:
            self._file_handle.close()
            self._file_handle = None