---
title: PyTorch Distributed
---

The `pytorch_distributed_example.py` script demonstrates integrating ClearML into code that uses the PyTorch Distributed Communications Package (`torch.distributed`).

The script initializes a main Task and spawns subprocesses, each for an instance of that Task. The Task in each subprocess trains a neural network over a partitioned dataset (the torchvision built-in MNIST dataset), and reports (uploads) the following to the main Task:

* **Artifacts** - A dictionary containing different key-value pairs.
* **Scalars** - Loss reported as a scalar during training in each Task in a subprocess.
* **Hyperparameters** - Hyperparameters created in each Task are added to the hyperparameters in the main Task.

Each Task in a subprocess references the main Task by calling `Task.current_task()`, which always returns the main Task.

When the script runs, it creates an experiment named `test torch distributed` in the `examples` project.
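
The overall structure can be sketched roughly as follows. This is a minimal sketch, not the example script itself; the `gloo` backend, the four-worker count, and the address/port values are assumptions for illustration:

```python
# Minimal sketch: create the main Task, then spawn worker subprocesses that
# join a torch.distributed process group. Inside each worker,
# Task.current_task() returns the Task created in the parent process.
import os
from multiprocessing import Process

import torch.distributed as dist
from clearml import Task


def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'  # assumed single-machine setup
    os.environ['MASTER_PORT'] = '29500'      # arbitrary free port
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    # ... train on this worker's data partition and report to the main Task ...


if __name__ == '__main__':
    Task.init(project_name='examples', task_name='test torch distributed')
    world_size = 4  # illustrative worker count
    processes = [Process(target=worker, args=(rank, world_size)) for rank in range(world_size)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```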

## Artifacts

The example uploads a dictionary as an artifact in the main Task by calling `Task.upload_artifact()` on `Task.current_task()` (the main Task). The dictionary contains the subprocess's rank (`dist.get_rank()`), making each artifact unique.

```python
Task.current_task().upload_artifact(
    'temp {:02d}'.format(dist.get_rank()),
    artifact_object={'worker_rank': dist.get_rank()}
)
```

All of these artifacts appear in the main Task under **ARTIFACTS** > **OTHER**.

*Experiment artifacts*
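
To verify the uploaded dictionaries programmatically, the artifacts can also be read back from the main Task. The following is a hypothetical sketch; the Task ID placeholder is not part of the example:

```python
# Hypothetical sketch: fetch the main Task by ID (placeholder value) and
# read back each worker's uploaded dictionary.
from clearml import Task

task = Task.get_task(task_id='<main-task-id>')  # replace with the real ID
for name, artifact in task.artifacts.items():
    print(name, artifact.get())  # each dict contains its worker's rank
```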

## Scalars

Loss is reported to the main Task by calling `Logger.report_scalar()` on the logger returned by `Task.current_task().get_logger()`, which is the logger of the main Task. Since `Logger.report_scalar()` is called with the same title (`loss`) but a different series name (containing the subprocess's rank), all the loss scalar series are logged together.

```python
Task.current_task().get_logger().report_scalar(
    'loss',
    'worker {:02d}'.format(dist.get_rank()),
    value=loss.item(),
    iteration=i
)
```

The single scalar plot for loss appears in **SCALARS**.

*Experiment scalars*
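
Because every worker reports under the same `loss` title, the merged series can also be read back from the main Task. This is a hypothetical sketch using `Task.get_reported_scalars()`; the Task ID placeholder and the assumed result layout (a nested dict keyed by title and series) are assumptions for illustration:

```python
# Hypothetical sketch: get_reported_scalars() returns the logged plots,
# keyed by title and series, with 'x' (iterations) and 'y' (values).
from clearml import Task

task = Task.get_task(task_id='<main-task-id>')  # replace with the real ID
for series, data in task.get_reported_scalars()['loss'].items():
    print(series, list(zip(data['x'], data['y']))[:3])  # first few points
```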

## Hyperparameters

ClearML automatically logs the command line options defined with argparse. Since `Task.connect()` is called on `Task.current_task()`, the connected hyperparameters are logged in the main Task. A different hyperparameter key is used in each subprocess, so the subprocesses do not overwrite each other's values in the main Task.

```python
from random import randint

param = {'worker_{}_stuff'.format(dist.get_rank()): 'some stuff ' + str(randint(0, 100))}
Task.current_task().connect(param)
```

All the hyperparameters appear in **CONFIGURATION** > **HYPERPARAMETERS**.

*Experiment hyperparameters: Args*

*Experiment hyperparameters: General*
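
To confirm that each worker's key landed in the main Task, the parameters can be listed with `Task.get_parameters()`, which returns a flat `Section/name` mapping. A hypothetical sketch, with a placeholder Task ID:

```python
# Hypothetical sketch: connected dictionary entries appear under the
# General section, one 'worker_<rank>_stuff' key per subprocess.
from clearml import Task

task = Task.get_task(task_id='<main-task-id>')  # replace with the real ID
for name, value in task.get_parameters().items():
    if name.startswith('General/worker_'):
        print(name, '=', value)
```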

## Console

Output to the console, including the text messages printed from the main Task object and each subprocess, appears in **CONSOLE**.

*Experiment console log*