DreamCraft3D/threestudio/systems/utils.py
2023-12-15 17:44:44 +08:00

104 lines
3.0 KiB
Python

import sys
import warnings
from bisect import bisect_right
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import threestudio
def get_scheduler(name):
    """Look up a learning-rate scheduler class by name.

    Args:
        name: Attribute name to resolve on ``torch.optim.lr_scheduler``
            (for example ``"StepLR"``).

    Returns:
        The attribute of ``torch.optim.lr_scheduler`` called ``name``.

    Raises:
        NotImplementedError: If ``torch.optim.lr_scheduler`` has no attribute
            with that name.
    """
    try:
        return getattr(lr_scheduler, name)
    except AttributeError:
        # Keep the historical exception type, but say which name was rejected
        # so a typo in a config file is easy to spot.
        raise NotImplementedError(
            f"Unknown scheduler '{name}': not found in torch.optim.lr_scheduler"
        ) from None
def getattr_recursive(m, attr):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting at ``m``.

    Args:
        m: Object on which the first path component is looked up.
        attr: Attribute names joined by ``"."``; each component is looked up
            with ``getattr`` on the result of the previous one.

    Returns:
        The object reached after the last component.

    Raises:
        AttributeError: If any component does not exist on the object reached
            so far.
    """
    current = m
    for component in attr.split("."):
        current = getattr(current, component)
    return current
def get_parameters(model, name):
    """Return the trainable parameters found at a dotted path inside ``model``.

    Args:
        model: Root object the path is resolved against.
        name: Dotted attribute path, resolved with ``getattr_recursive``.

    Returns:
        ``target.parameters()`` when the resolved object is an ``nn.Module``,
        the object itself when it is an ``nn.Parameter``, and an empty list
        for anything else.
    """
    target = getattr_recursive(model, name)
    if isinstance(target, nn.Module):
        return target.parameters()
    if isinstance(target, nn.Parameter):
        return target
    # Neither a module nor a parameter: nothing to optimize at this path.
    return []
def parse_optimizer(config, model):
    """Build an optimizer for ``model`` from a configuration object.

    Args:
        config: Object exposing ``name`` (optimizer class name), ``args``
            (keyword arguments for the optimizer constructor) and, optionally,
            ``params``: a mapping from dotted parameter path to a mapping of
            per-group options.
        model: Object whose parameters are optimized.

    Returns:
        The constructed optimizer instance.
    """
    if hasattr(config, "params"):
        # One parameter group per configured path; the per-group options are
        # merged into the group dict after the "params" and "name" keys.
        params = []
        for name, args in config.params.items():
            params.append(
                {"params": get_parameters(model, name), "name": name, **args}
            )
        threestudio.debug(f"Specify optimizer params: {config.params}")
    else:
        params = model.parameters()

    # Pick the namespace that provides the requested optimizer class. The
    # optional dependencies are imported lazily, only when actually requested.
    optimizer_name = config.name
    if optimizer_name == "FusedAdam":
        import apex

        namespace = apex.optimizers
    elif optimizer_name == "Adan":
        from threestudio.systems import optimizers

        namespace = optimizers
    else:
        namespace = torch.optim

    return getattr(namespace, optimizer_name)(params, **config.args)
def parse_scheduler_to_instance(config, optimizer):
    """Recursively build a learning-rate scheduler instance from a config.

    Args:
        config: Object exposing ``name``. For the composite schedulers
            ``"ChainedScheduler"`` and ``"Sequential"`` / ``"SequentialLR"``
            it must also expose ``schedulers`` (an iterable of nested configs)
            and, for the sequential case, ``milestones``. For every other
            name it must expose ``args``, the keyword arguments passed to the
            scheduler constructor.
        optimizer: Optimizer the scheduler(s) are attached to.

    Returns:
        The constructed scheduler instance.

    Raises:
        AttributeError: If ``name`` is not a composite name and is not an
            attribute of ``torch.optim.lr_scheduler``.
    """
    name = config.name
    # "SequentialLR" is accepted alongside the legacy "Sequential" spelling so
    # that this function agrees with parse_scheduler, which uses the torch
    # class name. Previously "SequentialLR" fell through to the generic branch.
    if name in ("ChainedScheduler", "Sequential", "SequentialLR"):
        children = [
            parse_scheduler_to_instance(conf, optimizer)
            for conf in config.schedulers
        ]
        if name == "ChainedScheduler":
            return lr_scheduler.ChainedScheduler(children)
        return lr_scheduler.SequentialLR(
            optimizer, children, milestones=config.milestones
        )
    return getattr(lr_scheduler, name)(optimizer, **config.args)
def parse_scheduler(config, optimizer):
    """Build a scheduler description dict from a configuration object.

    Args:
        config: Object exposing ``name`` and a ``get`` method. ``interval``
            is read via ``config.get`` and defaults to ``"epoch"``; it must be
            ``"epoch"`` or ``"step"``. ``"SequentialLR"`` and
            ``"ChainedScheduler"`` additionally need ``schedulers`` (nested
            configs) and, for ``"SequentialLR"``, ``milestones``. Any other
            name needs ``args``, the constructor keyword arguments.
        optimizer: Optimizer the scheduler(s) are attached to.

    Returns:
        A dict with the keys ``"scheduler"`` (the scheduler instance) and
        ``"interval"`` (the validated interval string).
    """
    interval = config.get("interval", "epoch")
    assert interval in ["epoch", "step"]

    kind = config.name
    if kind in ("SequentialLR", "ChainedScheduler"):
        # Composite schedulers: build each child recursively and keep only the
        # scheduler instance; the children's own "interval" entries are dropped.
        children = [
            parse_scheduler(conf, optimizer)["scheduler"]
            for conf in config.schedulers
        ]
        if kind == "SequentialLR":
            instance = lr_scheduler.SequentialLR(
                optimizer, children, milestones=config.milestones
            )
        else:
            instance = lr_scheduler.ChainedScheduler(children)
    else:
        instance = get_scheduler(kind)(optimizer, **config.args)

    return {"scheduler": instance, "interval": interval}