Iterate on your code locally using the PyTorch Trainer API

Determined is a powerful platform for training models at scale. But before launching a huge training job, you typically need to go through an iterative process of testing and debugging. This is where the PyTorch Trainer API comes in.

The PyTorch Trainer API allows you to:

  • Easily refine and test your model training code locally.
  • Debug your model using your preferred environment, be it your machine, an IDE, or a Jupyter notebook.

Let’s see how it works with the PyTorch MNIST example.

Local Training

First, load the yaml config file using the PyYAML library.

import logging

import yaml

import determined as det
import determined.pytorch  # makes sure det.pytorch is loaded before we use it
from model_def import MNistTrial

logging.getLogger().setLevel(logging.INFO)

# load the yaml config file
with open("const.yaml", "r") as file:
    exp_conf = yaml.safe_load(file)

hparams = exp_conf["hyperparameters"]

(If you don’t have PyYAML, you can install it with pip: pip install PyYAML.)

Next, create a PyTorchTrialContext and an MNistTrial, and pass them into the Trainer constructor:

with det.pytorch.init(hparams=hparams, exp_conf=exp_conf) as train_context:
    trial = MNistTrial(train_context)
    trainer = det.pytorch.Trainer(trial, train_context)

Then call the fit() function to train the model:

    trainer.fit(max_length=det.pytorch.Epoch(1))

You can run this code from the command line (python train.py) or in a Jupyter notebook.
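Putting the pieces together, a complete train.py for local training looks roughly like this (a minimal sketch, assuming const.yaml and the MNIST example’s model_def.py sit next to the script):

import logging

import yaml

import determined as det
import determined.pytorch  # makes sure det.pytorch is loaded before we use it
from model_def import MNistTrial


def main():
    logging.getLogger().setLevel(logging.INFO)

    # load the yaml config file and pull out the hyperparameters
    with open("const.yaml", "r") as file:
        exp_conf = yaml.safe_load(file)
    hparams = exp_conf["hyperparameters"]

    with det.pytorch.init(hparams=hparams, exp_conf=exp_conf) as train_context:
        trial = MNistTrial(train_context)
        trainer = det.pytorch.Trainer(trial, train_context)
        trainer.fit(max_length=det.pytorch.Epoch(1))


if __name__ == "__main__":
    main()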

The above snippet trains for one epoch. You can customize the Trainer’s behavior by passing additional arguments to fit(). For example, you can train for one epoch while saving a checkpoint and running validation every 100 batches:

    trainer.fit(
        max_length=det.pytorch.Epoch(1),
        checkpoint_period=det.pytorch.Batch(100),
        validation_period=det.pytorch.Batch(100),
        checkpoint_policy="all"
    )
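The checkpoint_policy="all" setting keeps every checkpoint taken at a validation point; Determined also supports "best" (keep only checkpoints that improve the validation metric) and "none".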

Sometimes you just want to see if your code will run at all. In this case, set test_mode=True to train and validate for just one batch:

    trainer.fit(
        max_length=det.pytorch.Epoch(1),
        test_mode=True
    )

From Local to Cluster

At some point, you’ll finish debugging your code locally, and you’ll want to train on a cluster. You may be wondering how many lines of code you need to change to do this.

The answer is… zero! All you have to do is set the yaml config’s entrypoint to point to your script:

entrypoint: python3 -m determined.launch.torch_distributed python3 train.py
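For reference, the rest of const.yaml might look something like this (a sketch only: the hyperparameter names below are illustrative ones from the MNIST example, and the exact required fields vary with your model and Determined version):

name: mnist_pytorch_const
hyperparameters:
  learning_rate: 1.0      # illustrative values for the MNIST example
  global_batch_size: 64
searcher:
  name: single            # run a single trial with these fixed hyperparameters
  metric: validation_loss
  smaller_is_better: true
entrypoint: python3 -m determined.launch.torch_distributed python3 train.py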

Then run the following from the command line, and the training job will launch on your cluster:

det e create const.yaml .

Switching Between Distributed Local and Cluster Training

Distributed training locally requires some extra boilerplate code that isn’t necessary when training on the cluster. So, to make our distributed script compatible with both local and cluster mode, we can add the following snippet:

import os
from torch import distributed as dist

# det.get_cluster_info() returns None when the script runs outside a Determined cluster
local = det.get_cluster_info() is None
if local:
    # local mode: initialize the torch.distributed process group ourselves
    dist.init_process_group(backend="gloo")
    os.environ["USE_TORCH_DISTRIBUTED"] = "true"
    distributed_context = det.core.DistributedContext.from_torch_distributed()
    latest_checkpoint = None
else:
    # cluster mode: Determined sets up distributed training and tracks checkpoints
    distributed_context = None
    latest_checkpoint = det.get_cluster_info().latest_checkpoint

Pass the distributed_context into the context initializer:

with det.pytorch.init(
    hparams=hparams,
    exp_conf=exp_conf,
    distributed=distributed_context,
) as train_context:

And pass latest_checkpoint into the fit() function:

    trainer.fit(
        max_length=det.pytorch.Epoch(1),
        latest_checkpoint=latest_checkpoint,
    )

The latest_checkpoint argument specifies which checkpoint should be used to start or continue training.
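Putting it all together, a distributed-friendly train.py built from the snippets above might look like this (the same sketch as before, now with the local/cluster switch):

import logging
import os

import yaml
from torch import distributed as dist

import determined as det
import determined.pytorch  # makes sure det.pytorch is loaded before we use it
from model_def import MNistTrial


def main():
    logging.getLogger().setLevel(logging.INFO)

    with open("const.yaml", "r") as file:
        exp_conf = yaml.safe_load(file)
    hparams = exp_conf["hyperparameters"]

    # decide whether we are running locally or on a Determined cluster
    local = det.get_cluster_info() is None
    if local:
        dist.init_process_group(backend="gloo")
        os.environ["USE_TORCH_DISTRIBUTED"] = "true"
        distributed_context = det.core.DistributedContext.from_torch_distributed()
        latest_checkpoint = None
    else:
        distributed_context = None
        latest_checkpoint = det.get_cluster_info().latest_checkpoint

    with det.pytorch.init(
        hparams=hparams,
        exp_conf=exp_conf,
        distributed=distributed_context,
    ) as train_context:
        trial = MNistTrial(train_context)
        trainer = det.pytorch.Trainer(trial, train_context)
        trainer.fit(
            max_length=det.pytorch.Epoch(1),
            latest_checkpoint=latest_checkpoint,
        )


if __name__ == "__main__":
    main()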

Now you can run distributed training locally:

torchrun --nproc_per_node=4 train.py
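Here torchrun spawns four worker processes on one machine (adjust --nproc_per_node as needed); because the snippet above initializes the process group with the gloo backend, this works even on a CPU-only machine.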

Or train on your cluster:

det e create const.yaml .

Note that the above code uses the PyTorch distributed backend, but you could use the Horovod backend instead.

Want to learn more?

That’s our brief guide on how to use Determined’s PyTorch Trainer API. If you want to learn more, check out the documentation, ask a question in the GitHub repo, or join our Slack community!