mirror of
https://github.com/clearml/clearml
synced 2025-06-23 01:55:38 +00:00
Fix jsonargparse support (#403)
This commit is contained in:
parent
ae734c81e7
commit
95785e7637
@ -1,5 +1,5 @@
|
||||
import ast
|
||||
import copy
|
||||
import six
|
||||
|
||||
try:
|
||||
from jsonargparse import ArgumentParser
|
||||
@ -9,6 +9,7 @@ except ImportError:
|
||||
|
||||
from ..config import running_remotely, get_remote_task_id
|
||||
from .frameworks import _patched_call # noqa
|
||||
from ..utilities.proxy_object import flatten_dictionary
|
||||
|
||||
|
||||
class PatchJsonArgParse(object):
|
||||
@ -55,38 +56,45 @@ class PatchJsonArgParse(object):
|
||||
try:
|
||||
PatchJsonArgParse._load_task_params()
|
||||
params = PatchJsonArgParse.__remote_task_params_dict
|
||||
params_namespace = Namespace()
|
||||
for k, v in params.items():
|
||||
if v == '':
|
||||
if v == "":
|
||||
v = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
v = ast.literal_eval(v)
|
||||
except Exception:
|
||||
pass
|
||||
params[k] = v
|
||||
params = PatchJsonArgParse.__unflatten_dict(params)
|
||||
params = PatchJsonArgParse.__nested_dict_to_namespace(params)
|
||||
return params
|
||||
params_namespace[k] = PatchJsonArgParse.__namespace_eval(v)
|
||||
return params_namespace
|
||||
except Exception:
|
||||
return original_fn(obj, **kwargs)
|
||||
orig_parsed_args = original_fn(obj, **kwargs)
|
||||
parsed_args = original_fn(obj, **kwargs)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
parsed_args = vars(copy.deepcopy(orig_parsed_args))
|
||||
for ns_name, ns_val in parsed_args.items():
|
||||
if not isinstance(ns_val, (Namespace, list)):
|
||||
PatchJsonArgParse._args[ns_name] = str(ns_val)
|
||||
if ns_name == PatchJsonArgParse._command_name:
|
||||
PatchJsonArgParse._args_type[ns_name] = PatchJsonArgParse._command_type
|
||||
else:
|
||||
ns_val = PatchJsonArgParse.__nested_namespace_to_dict(ns_val)
|
||||
ns_val = PatchJsonArgParse.__flatten_dict(ns_val, parent_name=ns_name)
|
||||
for k, v in ns_val.items():
|
||||
PatchJsonArgParse._args[k] = str(v)
|
||||
subcommand = None
|
||||
for ns_name, ns_val in Namespace(parsed_args).items():
|
||||
PatchJsonArgParse._args[ns_name] = ns_val
|
||||
if ns_name == PatchJsonArgParse._command_name:
|
||||
PatchJsonArgParse._args_type[ns_name] = PatchJsonArgParse._command_type
|
||||
subcommand = ns_val
|
||||
try:
|
||||
import pytorch_lightning
|
||||
except ImportError:
|
||||
pytorch_lightning = None
|
||||
if subcommand and subcommand in PatchJsonArgParse._args and pytorch_lightning:
|
||||
subcommand_args = flatten_dictionary(
|
||||
PatchJsonArgParse._args[subcommand],
|
||||
prefix=subcommand + PatchJsonArgParse._commands_sep,
|
||||
sep=PatchJsonArgParse._commands_sep,
|
||||
)
|
||||
del PatchJsonArgParse._args[subcommand]
|
||||
PatchJsonArgParse._args.update(subcommand_args)
|
||||
PatchJsonArgParse._args = {k: str(v) for k, v in PatchJsonArgParse._args.items()}
|
||||
PatchJsonArgParse._update_task_args()
|
||||
except Exception:
|
||||
pass
|
||||
return orig_parsed_args
|
||||
return parsed_args
|
||||
|
||||
@staticmethod
|
||||
def _load_task_params():
|
||||
@ -105,62 +113,15 @@ class PatchJsonArgParse(object):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def __nested_namespace_to_dict(namespace):
|
||||
if isinstance(namespace, list):
|
||||
return [PatchJsonArgParse.__nested_namespace_to_dict(n) for n in namespace]
|
||||
if not isinstance(namespace, Namespace):
|
||||
return namespace
|
||||
namespace = vars(namespace)
|
||||
for k, v in namespace.items():
|
||||
namespace[k] = PatchJsonArgParse.__nested_namespace_to_dict(v)
|
||||
return namespace
|
||||
|
||||
@staticmethod
|
||||
def __nested_dict_to_namespace(dict_):
|
||||
if isinstance(dict_, list):
|
||||
return [PatchJsonArgParse.__nested_dict_to_namespace(d) for d in dict_]
|
||||
if not isinstance(dict_, dict):
|
||||
return dict_
|
||||
for k, v in dict_.items():
|
||||
dict_[k] = PatchJsonArgParse.__nested_dict_to_namespace(v)
|
||||
return Namespace(**dict_)
|
||||
|
||||
@staticmethod
|
||||
def __flatten_dict(dict_, parent_name=None):
|
||||
if isinstance(dict_, list):
|
||||
if parent_name:
|
||||
return {parent_name: [PatchJsonArgParse.__flatten_dict(d) for d in dict_]}
|
||||
return [PatchJsonArgParse.__flatten_dict(d) for d in dict_]
|
||||
if not isinstance(dict_, dict):
|
||||
if parent_name:
|
||||
return {parent_name: dict_}
|
||||
return dict_
|
||||
result = {}
|
||||
for k, v in dict_.items():
|
||||
v = PatchJsonArgParse.__flatten_dict(v, parent_name=k)
|
||||
if isinstance(v, dict):
|
||||
for flattened_k, flattened_v in v.items():
|
||||
if parent_name:
|
||||
result[parent_name + PatchJsonArgParse._commands_sep + flattened_k] = flattened_v
|
||||
else:
|
||||
result[flattened_k] = flattened_v
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def __unflatten_dict(dict_):
|
||||
if isinstance(dict_, list):
|
||||
return [PatchJsonArgParse.__unflatten_dict(d) for d in dict_]
|
||||
if not isinstance(dict_, dict):
|
||||
return dict_
|
||||
result = {}
|
||||
for k, v in dict_.items():
|
||||
keys = k.split(PatchJsonArgParse._commands_sep)
|
||||
current_dict = result
|
||||
for k_part in keys[:-1]:
|
||||
if k_part not in current_dict:
|
||||
current_dict[k_part] = {}
|
||||
current_dict = current_dict[k_part]
|
||||
current_dict[keys[-1]] = PatchJsonArgParse.__unflatten_dict(v)
|
||||
return result
|
||||
def __namespace_eval(val):
|
||||
if isinstance(val, six.string_types) and val.startswith("Namespace(") and val[-1] == ")":
|
||||
val = val[len("Namespace("):]
|
||||
val = val[:-1]
|
||||
return Namespace(PatchJsonArgParse.__namespace_eval(ast.literal_eval("{" + val + "}")))
|
||||
if isinstance(val, list):
|
||||
return [PatchJsonArgParse.__namespace_eval(v) for v in val]
|
||||
if isinstance(val, dict):
|
||||
for k, v in val.items():
|
||||
val[k] = PatchJsonArgParse.__namespace_eval(v)
|
||||
return val
|
||||
return val
|
||||
|
@ -328,7 +328,8 @@ class Task(_Task):
|
||||
`ClearML Python Client Extras <./references/clearml_extras_storage/>`_ in the "ClearML Python Client
|
||||
Reference" section.
|
||||
|
||||
:param auto_connect_arg_parser: Automatically connect an argparse object to the Task
|
||||
:param auto_connect_arg_parser: Automatically connect an argparse object to the Task. Supported argument
|
||||
parsers packages are: argparse, click, python-fire, jsonargparse.
|
||||
|
||||
The values are:
|
||||
|
||||
|
@ -11,5 +11,5 @@ class Main:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="jsonargparse command", auto_connect_frameworks={"pytorch_lightning": False})
|
||||
Task.init(project_name="examples", task_name="jsonargparse command")
|
||||
print(CLI(Main))
|
@ -10,7 +10,7 @@ class Arg2:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="jsonargparse nested namespaces", auto_connect_frameworks={"pytorch-lightning": False})
|
||||
Task.init(project_name="examples", task_name="jsonargparse nested namespaces")
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--arg1.opt1", default="from default 1")
|
||||
parser.add_argument("--arg1.opt2", default="from default 2")
|
103
examples/frameworks/jsonargparse/pytorch_lightning_cli.py
Normal file
103
examples/frameworks/jsonargparse/pytorch_lightning_cli.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Notice that this file has been modified to examplify the use of
|
||||
# ClearML when used with PyTorch Lightning
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from torch.nn import functional as F
|
||||
import torch.nn as nn
|
||||
from torchmetrics import Accuracy
|
||||
|
||||
from torchvision.datasets.mnist import MNIST
|
||||
from pytorch_lightning import LightningModule
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from clearml import Task
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(1, 32, 3, 1)
|
||||
self.conv2 = nn.Conv2d(32, 64, 3, 1)
|
||||
self.dropout1 = nn.Dropout(0.25)
|
||||
self.dropout2 = nn.Dropout(0.5)
|
||||
self.fc1 = nn.Linear(9216, 128)
|
||||
self.fc2 = nn.Linear(128, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = F.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = F.relu(x)
|
||||
x = F.max_pool2d(x, 2)
|
||||
x = self.dropout1(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc1(x)
|
||||
x = F.relu(x)
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
output = F.log_softmax(x, dim=1)
|
||||
return output
|
||||
|
||||
|
||||
class ImageClassifier(LightningModule):
|
||||
def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32):
|
||||
super().__init__()
|
||||
self.save_hyperparameters(ignore="model")
|
||||
self.model = model or Net()
|
||||
self.test_acc = Accuracy()
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.forward(x)
|
||||
loss = F.nll_loss(logits, y.long())
|
||||
return loss
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.forward(x)
|
||||
loss = F.nll_loss(logits, y.long())
|
||||
self.test_acc(logits, y)
|
||||
self.log("test_acc", self.test_acc)
|
||||
self.log("test_loss", loss)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr)
|
||||
return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)]
|
||||
|
||||
@property
|
||||
def transform(self):
|
||||
return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
MNIST("./data", download=True)
|
||||
|
||||
def train_dataloader(self):
|
||||
train_dataset = MNIST("./data", train=True, download=False, transform=self.transform)
|
||||
return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size)
|
||||
|
||||
def test_dataloader(self):
|
||||
test_dataset = MNIST("./data", train=False, download=False, transform=self.transform)
|
||||
return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.add_requirements('requirements.txt')
|
||||
task = Task.init(project_name="example", task_name="pytorch_lightning_jsonargparse")
|
||||
LightningCLI(ImageClassifier, seed_everything_default=42, save_config_overwrite=True, run=True)
|
12
examples/frameworks/jsonargparse/pytorch_lightning_cli.yml
Normal file
12
examples/frameworks/jsonargparse/pytorch_lightning_cli.yml
Normal file
@ -0,0 +1,12 @@
|
||||
trainer:
|
||||
callbacks:
|
||||
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
|
||||
init_args:
|
||||
logging_interval: epoch
|
||||
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
|
||||
init_args:
|
||||
filename: best
|
||||
save_last: False
|
||||
save_top_k: 1
|
||||
monitor: loss
|
||||
mode: min
|
@ -1,2 +1,7 @@
|
||||
clearml
|
||||
jsonargparse
|
||||
jsonargparse
|
||||
pytorch_lightning
|
||||
torch
|
||||
torchmetrics
|
||||
torchvision
|
||||
docstring_parser
|
||||
|
Loading…
Reference in New Issue
Block a user