revital 2023-05-16 15:57:27 +03:00
commit 813890bc07
16 changed files with 381 additions and 165 deletions

View File

@ -1158,24 +1158,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
str_value = str(value)
if isinstance(value, (tuple, list, dict)):
if 'None' in re.split(r'[ ,\[\]{}()]', str_value):
# If we have None in the string we have to use json to replace it with null,
# otherwise we end up with None as string when running remotely
try:
str_json = json.dumps(value)
# verify we actually have a null in the string, otherwise prefer the str cast
# This is because we prefer to have \' as in str and not \" used in json
if 'null' in re.split(r'[ ,\[\]{}()]', str_json):
return str_json
except TypeError:
# if we somehow failed to json serialize, revert to previous std casting
pass
elif any('\\' in str(v) for v in value):
try:
str_json = json.dumps(value)
return str_json
except TypeError:
pass
try:
str_json = json.dumps(value)
return str_json
except TypeError:
pass
if isinstance(value, Enum):
# remove the class name
@ -1920,10 +1907,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:return: http/s URL link.
"""
return '{}/projects/{}/experiments/{}/output/log'.format(
self._get_app_server(),
self.project if self.project is not None else '*',
self.id,
return self.get_task_output_log_web_page(
task_id=self.id,
project_id=self.project,
app_server_host=self._get_app_server()
)
def get_reported_scalars(
@ -2727,6 +2714,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
task = get_single_result(entity='task', query=task_name, results=res.response.tasks)
return cls(task_id=task.id)
@classmethod
def get_task_output_log_web_page(cls, task_id, project_id=None, app_server_host=None):
# type: (str, Optional[str], Optional[str]) -> str
"""
Return the Task results & outputs web page address.
For example: https://demoapp.demo.clear.ml/projects/216431/experiments/60763e04/output/log
:param str task_id: Task ID.
:param str project_id: Project ID for this task.
:param str app_server_host: ClearML Application server host name.
If not provided, the current session will be used to resolve the host name.
:return: http/s URL link.
"""
if not app_server_host:
# single leading underscore so double-underscore name mangling does not break the hasattr check
if not hasattr(cls, "_cached_app_server_host"):
cls._cached_app_server_host = Session.get_app_server_host()
app_server_host = cls._cached_app_server_host
return "{}/projects/{}/experiments/{}/output/log".format(
app_server_host.rstrip("/"),
project_id if project_id is not None else '*',
task_id,
)
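A minimal usage sketch of the new helper, reusing the hypothetical IDs from the docstring example above:

url = Task.get_task_output_log_web_page(
    task_id="60763e04",
    project_id="216431",
    app_server_host="https://demoapp.demo.clear.ml",
)
# url == "https://demoapp.demo.clear.ml/projects/216431/experiments/60763e04/output/log"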
@classmethod
def _get_project_name(cls, project_id):
res = cls._send(cls._get_default_session(), projects.GetByIdRequest(project=project_id), raise_on_errors=False)

View File

@ -1,4 +1,5 @@
import json
import logging
try:
from jsonargparse import ArgumentParser
@ -102,7 +103,8 @@ class PatchJsonArgParse(object):
for k, v in params.items():
params_namespace[k] = v
return params_namespace
except Exception:
except Exception as e:
logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
return original_fn(obj, **kwargs)
parsed_args = original_fn(obj, **kwargs)
# noinspection PyBroadException
@ -114,10 +116,14 @@ class PatchJsonArgParse(object):
PatchJsonArgParse._args_type[ns_name] = PatchJsonArgParse._command_type
subcommand = ns_val
try:
import pytorch_lightning
import lightning
except ImportError:
pytorch_lightning = None
if subcommand and subcommand in PatchJsonArgParse._args and pytorch_lightning:
try:
import pytorch_lightning
lightning = pytorch_lightning
except ImportError:
lightning = None
if subcommand and subcommand in PatchJsonArgParse._args and lightning:
subcommand_args = flatten_dictionary(
PatchJsonArgParse._args[subcommand],
prefix=subcommand + PatchJsonArgParse._commands_sep,
@ -127,8 +133,8 @@ class PatchJsonArgParse(object):
PatchJsonArgParse._args.update(subcommand_args)
PatchJsonArgParse._args = {k: v for k, v in PatchJsonArgParse._args.items()}
PatchJsonArgParse._update_task_args()
except Exception:
pass
except Exception as e:
logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
return parsed_args
@staticmethod

View File

@ -205,6 +205,9 @@
# compatibility feature, report memory usage for the entire machine
# default (false), report only on the running process and its sub-processes
report_global_mem_used: false
# if provided, start resource reporting only after this many seconds
#report_start_sec: 30
}
}
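For example, to delay the first resource report by 30 seconds one would uncomment the setting (a sketch of the same worker block; the value feeds ResourceMonitor's first_report_sec in the Task changes further below):

worker {
    # report only after the process has been up for half a minute
    report_start_sec: 30
}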

View File

@ -463,8 +463,10 @@ class Dataset(object):
dataset_paths = itertools.repeat(dataset_path)
else:
if len(dataset_path) != len(source_url):
raise ValueError("dataset_path must be a string or a list of strings with the same length as source_url"
f" (received {len(dataset_path)} paths for {len(source_url)} source urls))")
raise ValueError(
"dataset_path must be a string or a list of strings with the same length as source_url"
f" (received {len(dataset_path)} paths for {len(source_url)} source urls)"
)
dataset_paths = dataset_path
with ThreadPoolExecutor(max_workers=max_workers) as tp:
for source_url_, dataset_path_ in zip(source_url_list, dataset_paths):
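For context, a call that satisfies the new length check might look like this (assuming the hunk sits inside Dataset.add_external_files; the names and bucket paths are made up):

ds = Dataset.create(dataset_name="demo", dataset_project="examples")
ds.add_external_files(
    source_url=["s3://bucket/a.png", "s3://bucket/b.png"],
    dataset_path=["images/a.png", "images/b.png"],  # one target path per source URL, or a single string
)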

View File

@ -10,6 +10,7 @@ import platform
import shutil
import sys
import threading
import uuid
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
@ -1036,19 +1037,19 @@ class StorageHelper(object):
def check_write_permissions(self, dest_path=None):
# create a temporary file, then delete it
base_url = dest_path or self._base_url
dest_path = base_url + '/.clearml.test'
dest_path = base_url + "/.clearml.{}.test".format(str(uuid.uuid4()))
# do not check http/s connection permissions
if dest_path.startswith('http'):
if dest_path.startswith("http"):
return True
try:
self.upload_from_stream(stream=six.BytesIO(b'clearml'), dest_path=dest_path)
self.upload_from_stream(stream=six.BytesIO(b"clearml"), dest_path=dest_path)
except Exception:
raise ValueError('Insufficient permissions (write failed) for {}'.format(base_url))
raise ValueError("Insufficient permissions (write failed) for {}".format(base_url))
try:
self.delete(path=dest_path)
except Exception:
raise ValueError('Insufficient permissions (delete failed) for {}'.format(base_url))
raise ValueError("Insufficient permissions (delete failed) for {}".format(base_url))
return True
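A short sketch exercising the randomized probe name (the bucket URL is hypothetical, and StorageHelper.get is assumed to resolve a helper for it):

helper = StorageHelper.get("s3://my-bucket/outputs")
# each call now writes a unique ".clearml.<uuid>.test" object, so concurrent checks cannot collide
helper.check_write_permissions()  # raises ValueError if write or delete is not permitted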
@classmethod

View File

@ -1,4 +1,5 @@
import atexit
import copy
import json
import os
import shutil
@ -100,6 +101,7 @@ from .utilities.lowlevel.threads import get_current_thread_id
from .utilities.process.mp import BackgroundMonitor, leave_process
from .utilities.matching import matches_any_wildcard
from .utilities.parallel import FutureTaskCaller
from .utilities.networking import get_private_ip
# noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments
@ -174,6 +176,9 @@ class Task(_Task):
__detect_repo_async = deferred_config('development.vcs_repo_detect_async', False)
__default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config('development.default_output_uri', None)
_launch_multi_node_section = "launch_multi_node"
_launch_multi_node_instance_tag = "multi_node_instance"
class _ConnectedParametersType(object):
argparse = "argument_parser"
dictionary = "dictionary"
@ -702,8 +707,10 @@ class Task(_Task):
resource_monitor_cls = auto_resource_monitoring \
if isinstance(auto_resource_monitoring, six.class_types) else ResourceMonitor
task._resource_monitor = resource_monitor_cls(
task, report_mem_used_per_process=not config.get(
'development.worker.report_global_mem_used', False))
task,
report_mem_used_per_process=not config.get('development.worker.report_global_mem_used', False),
first_report_sec=config.get('development.worker.report_start_sec', None),
)
task._resource_monitor.start()
# make sure all random generators are initialized with new seed
@ -1025,19 +1032,18 @@ class Task(_Task):
:return: The Tasks specified by the parameter combinations (see the parameters).
"""
task_filter = task_filter or {}
if tags:
task_filter = task_filter or {}
task_filter['tags'] = (task_filter.get('tags') or []) + list(tags)
return_fields = {}
if additional_return_fields:
task_filter = task_filter or {}
return_fields = set(list(additional_return_fields) + ['id'])
task_filter['only_fields'] = (task_filter.get('only_fields') or []) + list(return_fields)
if task_filter.get('type'):
task_filter['type'] = [str(task_type) for task_type in task_filter['type']]
results = cls._query_tasks(project_name=project_name, task_name=task_name, **(task_filter or {}))
results = cls._query_tasks(project_name=project_name, task_name=task_name, **task_filter)
return [t.id for t in results] if not additional_return_fields else \
[{k: cls._get_data_property(prop_path=k, data=r, raise_on_error=False, log_on_error=False)
for k in return_fields}
@ -1637,6 +1643,144 @@ class Task(_Task):
"""
return self._get_logger(auto_connect_streams=self._log_to_backend)
def launch_multi_node(self, total_num_nodes, port=29500, queue=None, wait=False, addr=None):
# type: (int, Optional[int], Optional[str], bool, Optional[str]) -> dict
"""
Enqueue multiple clones of the current task to a queue, allowing the task
to be run by multiple workers in parallel. Each task running this way is called a node.
Each node has a rank. The node that initiated the execution of the other nodes
is called the `master node`, and it has a rank equal to 0.
A dictionary named `multi_node_instance` will be connected to the tasks.
One can use this dictionary to modify the behaviour of this function when running remotely.
The contents of this dictionary correspond to the parameters of this function, and they are:
- `total_num_nodes` - the total number of nodes, including the master node
- `queue` - the queue to enqueue the nodes to
The following environment variables will be set:
- `MASTER_ADDR` - the address of the machine that the master node is running on
- `MASTER_PORT` - the open port of the machine that the master node is running on
- `WORLD_SIZE` - the total number of nodes, including the master
- `RANK` - the rank of the current node (master has rank 0)
One may use this function in conjunction with PyTorch's distributed communication package.
Note that `Task.launch_multi_node` should be called before `torch.distributed.init_process_group`.
For example:
.. code-block:: py
from clearml import Task
import torch
import torch.distributed as dist
def run(rank, size):
print('World size is ', size)
tensor = torch.zeros(1)
if rank == 0:
for i in range(1, size):
tensor += 1
dist.send(tensor=tensor, dst=i)
print('Sending from rank ', rank, ' to rank ', i, ' data: ', tensor[0])
else:
dist.recv(tensor=tensor, src=0)
print('Rank ', rank, ' received data: ', tensor[0])
if __name__ == '__main__':
task = Task.init('some_name', 'some_name')
task.execute_remotely(queue_name='queue')
config = task.launch_multi_node(4)
dist.init_process_group('gloo')
run(config.get('node_rank'), config.get('total_num_nodes'))
:param total_num_nodes: The total number of nodes to be enqueued, including the master node,
which should already be enqueued when running remotely
:param port: Port opened by the master node. If the environment variable `CLEARML_MULTI_NODE_MASTER_DEF_PORT`
is set, the value of this parameter will be set to the one defined in `CLEARML_MULTI_NODE_MASTER_DEF_PORT`.
If `CLEARML_MULTI_NODE_MASTER_DEF_PORT` doesn't exist, but `MASTER_PORT` does, then the value of this
parameter will be set to the one defined in `MASTER_PORT`. If neither environment variables exist,
the value passed to the parameter will be used
:param queue: The queue to enqueue the nodes to. Can be different than the queue the master
node is enqueued to. If None, the nodes will be enqueued to the same queue as the master node
:param wait: If True, the master node will wait for the other nodes to start
:param addr: The address of the master node's worker. If not set, it defaults to the private IP
of the machine the master is running on
:return: A dictionary containing relevant information regarding the multi node run. This dictionary
has the following entries:
- `master_addr` - the address of the machine that the master node is running on
- `master_port` - the open port of the machine that the master node is running on
- `total_num_nodes` - the total number of nodes, including the master
- `queue` - the queue the nodes are enqueued to, excluding the master
- `node_rank` - the rank of the current node (master has rank 0)
- `wait` - if True, the master node will wait for the other nodes to start
"""
def set_launch_multi_node_runtime_props(task, conf):
# noinspection PyProtectedMember
task._set_runtime_properties({"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()})
if total_num_nodes < 1:
raise UsageError("total_num_nodes needs to be at least 1")
if running_remotely() and not (self.data.execution and self.data.execution.queue) and not queue:
raise UsageError("Master task is not enqueued to any queue and the queue parameter is None")
master_conf = {
"master_addr": get_private_ip(),
"master_port": int(os.environ.get("CLEARML_MULTI_NODE_MASTER_DEF_PORT", os.environ.get("MASTER_PORT", port))),
"node_rank": 0,
"wait": wait
}
editable_conf = {"total_num_nodes": total_num_nodes, "queue": queue}
editable_conf = self.connect(editable_conf, name=self._launch_multi_node_section)
if not running_remotely():
return master_conf
master_conf.update(editable_conf)
runtime_properties = self._get_runtime_properties()
remote_node_rank = runtime_properties.get("{}/node_rank".format(self._launch_multi_node_section))
if remote_node_rank:
# self is a child node, build the conf from the runtime properties
current_conf = {
entry: runtime_properties.get("{}/{}".format(self._launch_multi_node_section, entry))
for entry in master_conf.keys()
}
else:
nodes_to_wait = []
# self is the master node, enqueue the other nodes
set_launch_multi_node_runtime_props(self, master_conf)
current_conf = master_conf
for node_rank in range(1, master_conf.get("total_num_nodes", total_num_nodes)):
node = self.clone(source_task=self)
node_conf = copy.deepcopy(master_conf)
node_conf["node_rank"] = node_rank
set_launch_multi_node_runtime_props(node, node_conf)
node.set_system_tags(node.get_system_tags() + [self._launch_multi_node_instance_tag])
if master_conf.get("queue"):
Task.enqueue(node, queue_name=master_conf["queue"])
else:
Task.enqueue(node, queue_id=self.data.execution.queue)
if master_conf.get("wait"):
nodes_to_wait.append(node)
for node_to_wait, rank in zip(nodes_to_wait, range(1, master_conf.get("total_num_nodes", total_num_nodes))):
self.log.info("Waiting for node with task ID {} and rank {}".format(node_to_wait.id, rank))
node_to_wait.wait_for_status(
status=(
Task.TaskStatusEnum.completed,
Task.TaskStatusEnum.stopped,
Task.TaskStatusEnum.closed,
Task.TaskStatusEnum.failed,
Task.TaskStatusEnum.in_progress
),
check_interval_sec=10
)
self.log.info("Node with task ID {} and rank {} detected".format(node_to_wait.id, rank))
os.environ["MASTER_ADDR"] = current_conf.get("master_addr", "")
os.environ["MASTER_PORT"] = str(current_conf.get("master_port", ""))
os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", ""))
os.environ["RANK"] = str(current_conf.get("node_rank", ""))
return current_conf
def mark_started(self, force=False):
# type: (bool) -> ()
"""
@ -4506,3 +4650,14 @@ class Task(_Task):
auto_connect_frameworks={'detect_repository': False}) \
if state['main'] else Task.get_task(task_id=state['id'])
self.__dict__ = task.__dict__
def __getattr__(self, name):
try:
return self.__getattribute__(name)
except AttributeError as e:
if self.__class__ is Task:
getLogger().warning(
"'clearml.Task' object has no attribute '{}'. Did you mean to import 'Task' from 'allegroai'?".format(name)
)
raise e
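A hypothetical trigger for the new warning path:

task = Task.init(project_name="examples", task_name="attr demo")
task.some_allegroai_only_member  # logs the "Did you mean to import 'Task' from 'allegroai'?" hint, then raises AttributeError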

View File

@ -1 +1 @@
__version__ = '1.10.3'
__version__ = '1.10.4'

View File

@ -32,11 +32,11 @@ params = {
params = task.connect(params) # enabling configuration override by clearml/
print(params) # printing actual configuration (after override in remote mode)
# The below gets the dataset and stores it in the cache. To download the dataset regardless of whether it is
# in the cache, use Dataset.get(dataset_name, dataset_project).get_mutable_local_copy(path_to_download)
# The dataset needs to be in a finalized or closed state for a local copy to be retrieved
dataset_path = Dataset.get(
dataset_name=dataset_name, dataset_project=dataset_project
dataset_name=dataset_name, dataset_project=dataset_project, only_completed=False
).get_local_copy()
# Dataset and Dataloader initializations

View File

@ -1,103 +1,14 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Notice that this file has been modified to exemplify the use of
# ClearML when used with PyTorch Lightning
import torch
import torchvision.transforms as T
from torch.nn import functional as F
import torch.nn as nn
from torchmetrics import Accuracy
from torchvision.datasets.mnist import MNIST
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI
try:
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
except ImportError:
import sys
print("Module 'lightning' not installed (only available for Python 3.8+")
sys.exit(0)
from clearml import Task
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
class ImageClassifier(LightningModule):
def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32):
super().__init__()
self.save_hyperparameters(ignore="model")
self.model = model or Net()
self.test_acc = Accuracy()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.nll_loss(logits, y.long())
return loss
def test_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.nll_loss(logits, y.long())
self.test_acc(logits, y)
self.log("test_acc", self.test_acc)
self.log("test_loss", loss)
def configure_optimizers(self):
optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr)
return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)]
@property
def transform(self):
return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
def prepare_data(self) -> None:
MNIST("./data", download=True)
def train_dataloader(self):
train_dataset = MNIST("./data", train=True, download=False, transform=self.transform)
return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size)
def test_dataloader(self):
test_dataset = MNIST("./data", train=False, download=False, transform=self.transform)
return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size)
if __name__ == "__main__":
Task.add_requirements('requirements.txt')
task = Task.init(project_name="example", task_name="pytorch_lightning_jsonargparse")
LightningCLI(ImageClassifier, seed_everything_default=42, save_config_overwrite=True, run=True)
Task.add_requirements("requirements.txt")
Task.init(project_name="example", task_name="pytorch_lightning_jsonargparse")
LightningCLI(DemoModel, BoringDataModule)

View File

@ -1,12 +1,13 @@
trainer:
callbacks:
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
filename: best
save_last: False
save_top_k: 1
monitor: loss
mode: min
max_epochs: 10

View File

@ -0,0 +1,114 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Notice that this file has been modified to exemplify the use of
# ClearML when used with PyTorch Lightning
import sys
import torch
import torchvision.transforms as T
from torch.nn import functional as F
import torch.nn as nn
from torchmetrics import Accuracy
from torchvision.datasets.mnist import MNIST
from pytorch_lightning import LightningModule
from clearml import Task
try:
from pytorch_lightning.cli import LightningCLI
except ImportError:
try:
from pytorch_lightning.utilities.cli import LightningCLI
except ImportError:
print("Looks like you are using pytorch_lightning>=2.0. This example only works with older versions")
sys.exit(0)
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
class ImageClassifier(LightningModule):
def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32):
super().__init__()
self.save_hyperparameters(ignore="model")
self.model = model or Net()
try:
self.test_acc = Accuracy()
except TypeError:
self.test_acc = Accuracy("binary")
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.nll_loss(logits, y.long())
return loss
def test_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.nll_loss(logits, y.long())
self.test_acc(logits, y)
self.log("test_acc", self.test_acc)
self.log("test_loss", loss)
def configure_optimizers(self):
optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr)
return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)]
@property
def transform(self):
return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
def prepare_data(self) -> None:
MNIST("./data", download=True)
def train_dataloader(self):
train_dataset = MNIST("./data", train=True, download=False, transform=self.transform)
return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size)
def test_dataloader(self):
test_dataset = MNIST("./data", train=False, download=False, transform=self.transform)
return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size)
if __name__ == "__main__":
Task.add_requirements("requirements.txt")
Task.init(project_name="example", task_name="pytorch_lightning_jsonargparse")
LightningCLI(ImageClassifier, seed_everything_default=42, run=True)

View File

@ -0,0 +1,12 @@
trainer:
callbacks:
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
init_args:
filename: best
save_last: False
save_top_k: 1
monitor: loss
mode: min

View File

@ -1,7 +1,8 @@
clearml
jsonargparse
pytorch_lightning
torch
torchmetrics
torchvision
docstring_parser
pytorch-lightning[extra]
lightning; python_version >= '3.8'

View File

@ -3,12 +3,11 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from mmcv.parallel import MMDataParallel
from mmcv.runner import EpochBasedRunner
from mmcv.utils import get_logger
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
class Model(nn.Module):

View File

@ -1,4 +1,4 @@
clearml
mmcv>=1.5.1
mmcv>=1.5.1,<2.0.0
torch
torchvision

View File

@ -1,13 +1,14 @@
import os
import sys
from argparse import ArgumentParser
import torch
import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from clearml import Task
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
from torchvision.datasets.mnist import MNIST
from clearml import Task
class LitClassifier(pl.LightningModule):
@ -35,12 +36,13 @@ class LitClassifier(pl.LightningModule):
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('valid_loss', loss)
return loss
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('test_loss', loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@ -54,19 +56,17 @@ class LitClassifier(pl.LightningModule):
if __name__ == '__main__':
# Connecting ClearML with the current process,
# from here on everything is logged automatically
task = Task.init(project_name="examples", task_name="PyTorch lightning MNIST example")
pl.seed_everything(0)
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser = pl.Trainer.add_argparse_args(parser)
parser.set_defaults(max_epochs=3)
parser.add_argument('--max_epochs', default=3, type=int)
sys.argv.extend(['--max_epochs', '2'])
parser = LitClassifier.add_model_specific_args(parser)
args = parser.parse_args()
Task.init(project_name="examples-internal", task_name="lightning checkpoint issue and argparser")
# ------------
# data
# ------------
@ -74,9 +74,9 @@ if __name__ == '__main__':
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=args.batch_size, num_workers=os.cpu_count())
val_loader = DataLoader(mnist_val, batch_size=args.batch_size, num_workers=os.cpu_count())
test_loader = DataLoader(mnist_test, batch_size=args.batch_size, num_workers=os.cpu_count())
train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
# ------------
# model
@ -86,7 +86,7 @@ if __name__ == '__main__':
# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer = pl.Trainer(max_epochs=args.max_epochs)
trainer.fit(model, train_loader, val_loader)
# ------------