How to Structure Your Code to Build a Scalable Deep Learning Model

Most ML code that floats around in open source repositories isn’t built to scale. To demonstrate this, we’re going to look at two very similar codebases that implement a classifier for the MNIST dataset. Both codebases use Keras — one codebase uses Keras itself, while the second uses Determined’s Keras API. We’ll show how one codebase turns into a nightmare as we add more capabilities, while the other will scale with almost no changes.

First, some Python boilerplate:

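Traditional
import keras
from keras.datasets import mnist
from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from keras.models import Sequential

def main():
    pass

if __name__ == '__main__':
    main()
Determined
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from determined.keras import InputData, TFKerasTrial, TFKerasTrialContext

class MNISTTrial(TFKerasTrial):
    def __init__(self, context: TFKerasTrialContext) -> None:
        self.context = context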

Now let’s load the data we’ll use to train and test our model. Because MNIST is a small, simple dataset, this is quick; with large, custom datasets, data loading gets considerably more complicated.

Traditional
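def load_data():
    img_rows, img_cols = 28, 28
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Add the channel axis Conv2D expects and convert to float before scaling;
    # in-place division on the raw uint8 arrays raises a TypeError.
    x_train = x_train.reshape(-1, img_rows, img_cols, 1).astype('float32') / 255
    x_test = x_test.reshape(-1, img_rows, img_cols, 1).astype('float32') / 255
    return (x_train, y_train), (x_test, y_test)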
Determined
    def build_training_data_loader(self) -> InputData:
        (train_images, train_labels), _ = mnist.load_data()
        # Scale pixels to [0, 1] and add the channel axis that Conv2D expects.
        train_images = train_images.reshape(-1, 28, 28, 1) / 255.0

        return train_images, train_labels

    def build_validation_data_loader(self) -> InputData:
        _, (test_images, test_labels) = mnist.load_data()
        # Apply the same preprocessing as the training data.
        test_images = test_images.reshape(-1, 28, 28, 1) / 255.0

        return test_images, test_labels
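
MNIST fits in memory, so both loaders can simply return arrays. With a larger dataset, one common approach is to stream batches lazily instead. The following is a minimal sketch in plain Python; `normalize` and `iter_batches` are hypothetical helper names used for illustration and are not part of Keras or Determined.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Sequence


def normalize(pixels: Sequence[int]) -> List[float]:
    """Scale 8-bit pixel values from [0, 255] to [0.0, 1.0]."""
    return [p / 255.0 for p in pixels]


def iter_batches(
    samples: Iterable[Sequence[int]], batch_size: int
) -> Iterator[List[List[float]]]:
    """Lazily yield normalized batches so the full dataset never sits in memory."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    iterator = iter(samples)
    while True:
        # islice pulls at most batch_size samples from the underlying stream.
        batch = [normalize(s) for s in islice(iterator, batch_size)]
        if not batch:
            return
        yield batch


# Stand-in for images read from disk: five tiny "images" of two pixels each.
fake_images = ([0, 255] for _ in range(5))
batches = list(iter_batches(fake_images, batch_size=2))
batch_sizes = [len(b) for b in batches]
```

The last batch is smaller than the others because five samples do not divide evenly by two; a real pipeline would have to decide whether to keep or drop that remainder.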

Now the fun part, defining the model itself:

Traditional
def build_model(input_shape, num_classes=10):
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))

    # load_data() returns integer class labels, so use the sparse loss.
    model.compile(loss=keras.losses.sparse_categorical_crossentropy,
                  optimizer=keras.optimizers.RMSprop(),
                  metrics=['accuracy'])
    return model
Determined
    def build_model(self):
        model = tf.keras.Sequential()
        model.add(Conv2D(32, kernel_size=(3, 3),
                         activation='relu',
                         input_shape=(28, 28, 1)))
        model.add(Conv2D(64, (3, 3), activation='relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(self.context.get_hparam("dropout1")))
        model.add(Flatten())
        model.add(Dense(128, activation='relu'))
        model.add(Dropout(self.context.get_hparam("dropout2")))
        # MNIST has 10 digit classes.
        model.add(Dense(10, activation='softmax'))
        model = self.context.wrap_model(model)
        model.compile(
            optimizer=tf.keras.optimizers.RMSprop(
                learning_rate=self.context.get_hparam("learning_rate"),
            ),
            # The last layer already applies softmax, so the loss receives
            # probabilities rather than logits.
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
        )
        return model

Finally, we’ll add the last few pieces we need to actually train the model locally:

Traditional
def main():
    batch_size = 128
    num_classes = 10
    epochs = 12
    input_shape = (28, 28, 1)
    (x_train, y_train), (x_test, y_test) = load_data()
    model = build_model(input_shape)
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
Determined
description: mnist_tf_keras_const
hyperparameters:
  learning_rate: 1e-4
  dropout1: .25
  dropout2: .5
searcher:
  name: single
  metric: accuracy
  smaller_is_better: false  # accuracy should be maximized
  max_steps: 50
entrypoint: model_def:MNISTTrial

The Traditional sample looks like most of the ML code that exists across the internet, including throughout dozens of popular repositories (e.g., StyleGAN, Wide ResNets). The Determined sample is structured: it’s built to conform to a specific API, in this case the Determined Trial interface. The two are almost identical: they took about the same amount of time to write and are essentially the same length, and each can be used to train a model with a single command.
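
One way to picture the link between the config file and the `get_hparam` calls in the model is the toy sketch below. `SketchContext` and `describe_model` are hypothetical stand-ins written for this illustration; they are not Determined's implementation, only a picture of the pattern in which model code asks a context object for values rather than hard-coding them.

```python
from typing import Any, Dict


class SketchContext:
    """Hypothetical stand-in for a trial context; not Determined's implementation."""

    def __init__(self, hparams: Dict[str, Any]) -> None:
        self._hparams = dict(hparams)

    def get_hparam(self, name: str) -> Any:
        # Fail loudly on a typo instead of silently training with a default.
        if name not in self._hparams:
            raise KeyError(f"hyperparameter {name!r} is not defined in the config")
        return self._hparams[name]


def describe_model(context: SketchContext) -> Dict[str, float]:
    """Model code reads its settings from the context instead of hard-coding them."""
    return {
        "dropout1": context.get_hparam("dropout1"),
        "dropout2": context.get_hparam("dropout2"),
        "learning_rate": context.get_hparam("learning_rate"),
    }


# Values mirror the hyperparameters section of the config file.
config_hparams = {"learning_rate": 1e-4, "dropout1": 0.25, "dropout2": 0.5}
settings = describe_model(SketchContext(config_hparams))
```

Because the model never sees where the values come from, the same code can be handed a constant or a value chosen by a search.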

The hard part starts when you need to scale this code out for a larger problem than MNIST. You’ll be working with a large, custom dataset and you’ll need a bigger model to fit that dataset. Suddenly your local machine won’t be able to train the model and you’ll need to investigate distributed training with multiple GPUs. As you perfect your model you’ll need to perform some sort of hyperparameter search. In most cases, your code will evolve something like this:

To meet the challenges of your larger problem, you’ll need to layer a lot of complexity onto your codebase. Maybe you’ll distribute training using Horovod and a GPU cluster, or maybe you’ll set up a hyperparameter search framework like Ray Tune to tune your model. Each of these steps comes with significant overhead: you’ll need to write new code to integrate these pieces and configure complex environments, and even then, without a large IT team to help, you’ll likely run into resiliency issues.
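
To give a feel for the glue code involved, here is a minimal sketch of a hand-rolled grid search in plain Python. It does not use Horovod or Ray Tune; `grid`, `search`, and `train_and_evaluate` are hypothetical names, and the "loss" is a synthetic formula standing in for a full training run.

```python
from itertools import product
from typing import Dict, Iterator, List, Optional, Tuple


def grid(space: Dict[str, List[float]]) -> Iterator[Dict[str, float]]:
    """Yield every combination of the candidate values in `space`."""
    names = sorted(space)
    for values in product(*(space[n] for n in names)):
        yield dict(zip(names, values))


def train_and_evaluate(hparams: Dict[str, float]) -> float:
    """Placeholder for a full training run; returns a made-up validation loss.

    The formula is synthetic and only gives the search something to minimize.
    """
    return abs(hparams["learning_rate"] - 1e-3) * 100 + abs(hparams["dropout1"] - 0.25)


def search(space: Dict[str, List[float]]) -> Tuple[Optional[Dict[str, float]], float]:
    """Try every combination sequentially and keep the one with the lowest loss."""
    best_loss, best_hparams = float("inf"), None
    for hparams in grid(space):
        loss = train_and_evaluate(hparams)  # one trial at a time
        if loss < best_loss:
            best_loss, best_hparams = loss, hparams
    return best_hparams, best_loss


space = {"learning_rate": [1e-4, 1e-3, 1e-2], "dropout1": [0.25, 0.5]}
best_hparams, best_loss = search(space)
```

Even this toy version leaves out parallel trials, early stopping, checkpointing, and failure recovery, which is where most of the integration effort tends to go.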

And it gets worse. In reality, your code will probably look more like this:

Pictured: Smashed code together

As more complexity gets layered in, there usually isn’t a clean way to separate the different pieces. Infrastructure code ends up everywhere, your hyperparameter tuning gets tangled up with your training loops, and distributed training logic shows up in six different places. Imagine having to hand that off to a coworker! Working on these jumbled codebases can be a nightmare, and you’ve completely lost the clean model definition you had at the beginning. Even worse, the next time you build a model, you’ll go through all of these same challenges again!

Now let’s look at how your structured model will evolve:

You just need a config file. By conforming to a structured API, you can take advantage of a wide range of tools Determined has to offer (like automatic hyperparameter tuning, distributed training, and fault tolerance), without having to manually integrate all of these pieces into your code. By simply modifying your config file, you’ll be able to specify sophisticated experiments, like cluster-wide distributed training or massively parallel hyperparameter tuning.
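
As a rough illustration of why a structured API makes this possible, the toy sketch below separates user code from platform code. Everything in it (`ToyTrial`, `expand`, `run_experiment`, and the shape of the config dict) is hypothetical and far simpler than Determined; it only shows how changing the config alone can turn a single run into a search while the trial class stays untouched.

```python
from itertools import product
from typing import Any, Dict, List


class ToyTrial:
    """User code: knows about the model, nothing about search or scheduling."""

    def __init__(self, hparams: Dict[str, Any]) -> None:
        self.hparams = hparams

    def train(self) -> float:
        # Synthetic validation loss standing in for real training.
        return abs(self.hparams["dropout1"] - 0.25) + abs(self.hparams["dropout2"] - 0.5)


def expand(hparams: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn list-valued hyperparameters into one dict per combination."""
    names = list(hparams)
    choices = [v if isinstance(v, list) else [v] for v in hparams.values()]
    return [dict(zip(names, combo)) for combo in product(*choices)]


def run_experiment(config: Dict[str, Any], trial_cls=ToyTrial) -> Dict[str, Any]:
    """Platform code: reads the config and decides how many trials to run."""
    results = [(trial_cls(hp).train(), hp) for hp in expand(config["hyperparameters"])]
    best_loss, best_hp = min(results, key=lambda r: r[0])
    return {"trials": len(results), "best_hparams": best_hp, "best_loss": best_loss}


# Constant hyperparameters: a single trial.
single = run_experiment({"hyperparameters": {"dropout1": 0.25, "dropout2": 0.5}})

# Same trial class, different config: a six-trial grid search.
searched = run_experiment(
    {"hyperparameters": {"dropout1": [0.25, 0.5], "dropout2": [0.25, 0.5, 0.75]}}
)
```

The design point is that `ToyTrial` is identical in both runs; only the dictionary passed to `run_experiment` changes.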

For essentially the same initial investment, you get distributed training, hyperparameter tuning, fault tolerance, better code readability, and built-in cluster management. Determined makes it that easy. Try it out now or check out Determined on GitHub!
