June 12, 2020
Most ML code that floats around in open source repositories isn't built to scale. To demonstrate this, we're going to look at two very similar codebases that implement a classifier for the MNIST dataset. Both codebases use Keras: one uses Keras directly, while the other uses Determined's Keras API. In each pair of code samples below, the plain Keras version comes first and the Determined version second. We'll show how one codebase turns into a nightmare as we add more capabilities, while the other scales with almost no changes.
First, some Python boilerplate:
def main():
    pass


if __name__ == '__main__':
    main()
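import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from determined.keras import InputData, TFKerasTrial, TFKerasTrialContext


class MNISTTrial(TFKerasTrial):
    def __init__(self, context: TFKerasTrialContext) -> None:
        # The context gives the trial access to its hyperparameters and other settings
        self.context = context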
Now let's load the data we'll use to train and test our model. Since we're using a simple dataset (MNIST), this step is quick, but it gets considerably more complicated with large, custom datasets.
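from keras.datasets import mnist
from keras.utils import to_categorical


def load_data():
    img_rows, img_cols = 28, 28
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Add a channel axis and convert to float32 so the in-place division works
    x_train = x_train.reshape(-1, img_rows, img_cols, 1).astype('float32')
    x_test = x_test.reshape(-1, img_rows, img_cols, 1).astype('float32')
    x_train /= 255
    x_test /= 255
    # categorical_crossentropy expects one-hot labels (10 digit classes)
    y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)
    return (x_train, y_train), (x_test, y_test)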
def build_training_data_loader(self) -> InputData:
    (train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()
    # Scale pixels to [0, 1] and add the channel axis Conv2D expects
    train_images = train_images.reshape(-1, 28, 28, 1) / 255.0
    return train_images, train_labels

def build_validation_data_loader(self) -> InputData:
    _, (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
    test_images = test_images.reshape(-1, 28, 28, 1) / 255.0
    return test_images, test_labels
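With MNIST, both loaders simply read the whole dataset into memory. A large custom dataset usually can't be handled that way; one common alternative is to read and normalize examples lazily, one batch at a time. The sketch below assumes a hypothetical `load_example(i)` callable that returns one `(uint8 image, int label)` pair from wherever the data actually lives, and a hypothetical `fake_example` loader for the demo. It only illustrates the idea and is not part of either codebase.

```python
import numpy as np


def iter_batches(num_examples, batch_size, load_example, seed=None):
    """Lazily yield (images, labels) batches instead of loading everything at once.

    `load_example(i)` is a hypothetical callable returning one
    (uint8 image, int label) pair, e.g. read from disk or object storage.
    """
    order = np.arange(num_examples)
    if seed is not None:
        # Shuffle the example order reproducibly.
        np.random.default_rng(seed).shuffle(order)
    for start in range(0, num_examples, batch_size):
        pairs = [load_example(int(i)) for i in order[start:start + batch_size]]
        # Scale pixels to [0, 1] and add the channel axis Conv2D expects.
        images = np.stack([img for img, _ in pairs]).astype("float32") / 255.0
        labels = np.array([label for _, label in pairs], dtype="int64")
        yield images[..., np.newaxis], labels


def fake_example(i):
    # Hypothetical loader for this demo only: a constant-valued 28x28 image.
    return np.full((28, 28), i % 256, dtype=np.uint8), i % 10


batches = list(iter_batches(10, batch_size=4, load_example=fake_example))
print([images.shape for images, _ in batches])
# [(4, 28, 28, 1), (4, 28, 28, 1), (2, 28, 28, 1)]
```

Because only one batch is materialized at a time, memory use depends on the batch size rather than on the size of the dataset.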
Now for the fun part: defining the model itself.
import keras
from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from keras.models import Sequential


def build_model(input_shape, num_classes=10):
    model = Sequential()
    # Two convolutional layers followed by pooling and dropout
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # Dense classifier head; hyperparameters are hard-coded here
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.RMSprop(),
                  metrics=['accuracy'])
    return model
def build_model(self):
    input_shape, num_classes = (28, 28, 1), 10
    model = tf.keras.Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # Dropout rates come from the experiment config instead of being hard-coded
    model.add(Dropout(self.context.get_hparam("dropout1")))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(self.context.get_hparam("dropout2")))
    model.add(Dense(num_classes, activation='softmax'))
    # Hand the model to Determined before compiling it
    model = self.context.wrap_model(model)
    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(
            learning_rate=self.context.get_hparam("learning_rate"),
        ),
        # The last layer already applies softmax, so the loss receives probabilities, not logits
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
    )
    return model
Finally, we'll add the last few pieces we need to actually train the model. For plain Keras, that's a training loop in main(); for Determined, it's an experiment configuration file:
def main():
    batch_size = 128
    epochs = 12
    input_shape = (28, 28, 1)
    (x_train, y_train), (x_test, y_test) = load_data()
    model = build_model(input_shape)
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))
    # evaluate() returns [loss, accuracy] because the model was compiled with metrics=['accuracy']
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
description: mnist_tf_keras_const
hyperparameters:
  global_batch_size: 128
  learning_rate: 1.0e-4
  dropout1: 0.25
  dropout2: 0.5
searcher:
  name: single
  metric: val_accuracy
  smaller_is_better: false
  max_steps: 50
entrypoint: model_def:MNISTTrial
The first code sample in each pair looks like most of the ML code on the internet, including the code in dozens of popular repositories (e.g., StyleGAN, Wide ResNets). The second is structured: it's built to conform to a specific API, in this case the Determined Trial Interface. The two are almost identical. They took the same amount of time to write, they're essentially the same length, and each can be used to train a model with a single command.
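The practical difference is who owns the training loop. In the structured version, your class only supplies pieces (data, model, hyperparameter lookups), and the platform decides how and when to call them. The stdlib-only sketch below illustrates that inversion of control with a hypothetical `run_trial` function, a hypothetical `ToyContext`, and a toy one-weight model; it is not Determined's implementation, and every name in it is made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class ToyContext:
    """Hypothetical stand-in for a trial context: it only serves hyperparameters."""
    hparams: dict

    def get_hparam(self, name):
        return self.hparams[name]


class ToyTrial:
    """User code: supplies data and a 'model', but never writes a training loop."""

    def __init__(self, context):
        self.context = context

    def build_training_data(self):
        # Toy data for the relationship y = 3x.
        return [(x, 3.0 * x) for x in range(1, 6)]

    def build_model(self):
        # A single weight plus the learning rate read from the context.
        return {"w": 0.0, "lr": self.context.get_hparam("learning_rate")}


def run_trial(trial_cls, hparams, epochs):
    """Hypothetical platform code: it owns the loop, so it could add distribution,
    checkpointing, or hyperparameter search without touching ToyTrial."""
    trial = trial_cls(ToyContext(hparams))
    data = trial.build_training_data()
    model = trial.build_model()
    for _ in range(epochs):
        for x, y in data:
            # One SGD step on the squared error (w*x - y)^2.
            grad = 2 * (model["w"] * x - y) * x
            model["w"] -= model["lr"] * grad
    return model["w"]


print(round(run_trial(ToyTrial, {"learning_rate": 0.01}, epochs=50), 3))
# 3.0
```

Since `ToyTrial` never mentions loops, devices, or search, changing how it is run only means changing `run_trial` or the hyperparameters passed to it.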
The hard part starts when you need to scale this code out to a problem larger than MNIST. You'll be working with a large, custom dataset, and you'll need a bigger model to fit it. Suddenly your local machine won't be able to train the model, and you'll need to look into distributed training across multiple GPUs. As you refine your model, you'll also need to run some sort of hyperparameter search. In most cases, your code will evolve something like this:
To meet the challenges of your larger problem, you'll need to layer a lot of complexity onto your codebase. Maybe you'll distribute training across a GPU cluster using Horovod, or set up a hyperparameter search framework like Ray Tune to tune your model. Each of these steps comes with significant overhead: you'll need to write new code to integrate these pieces and configure complex environments, and even then, without a large IT team to help, you'll likely run into resiliency issues.
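As one small, concrete example of that integration work: synchronous data-parallel training with Horovod runs one copy of the training script per GPU, so the effective batch size grows with the number of workers, and Horovod's documentation recommends scaling the learning rate by the number of workers to compensate. The hypothetical helper below does only that bookkeeping; in a real Horovod script the worker count would come from `hvd.size()`, and the rest of the integration (optimizer wrapping, variable broadcasting, per-GPU device pinning) is still on you.

```python
def scale_for_workers(per_worker_batch_size, base_learning_rate, num_workers):
    """Return (effective_batch_size, scaled_learning_rate) for synchronous data parallelism.

    Hypothetical helper. Each worker processes `per_worker_batch_size` examples per
    step, so each averaged gradient covers per_worker_batch_size * num_workers examples.
    The learning rate is scaled linearly with the worker count, following the
    recommendation in Horovod's documentation.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    effective_batch_size = per_worker_batch_size * num_workers
    scaled_learning_rate = base_learning_rate * num_workers
    return effective_batch_size, scaled_learning_rate


# A batch size of 128 and a learning rate of 1e-4 (the values used in the examples above), on 8 GPUs.
print(scale_for_workers(128, 1e-4, 8))
```

Linear scaling is a heuristic rather than a guarantee, which is part of why moving to a cluster usually means re-tuning as well.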
And it gets worse. In reality, your code will probably look more like this:
As more complexity gets layered in, most of the time there isn't a way to cleanly separate the different pieces. There will be infrastructure code everywhere, your hyperparameter tuning will get tangled up with your training loops, and you'll have six different places where distributed training logic comes in. Imagine having to hand that off to a coworker! Working on these jumbled codebases can be a nightmare, and you've completely lost the clean model definition you had at the beginning. Even worse, the next time you build a model, you'll go through all of these same challenges again!
Now let’s look at how your structured model will evolve:
You just need a config file. By conforming to a structured API, you can take advantage of a wide range of tools Determined has to offer (like automatic hyperparameter tuning, distributed training, and fault tolerance), without having to manually integrate all of these pieces into your code. By simply modifying your config file, you’ll be able to specify sophisticated experiments, like cluster-wide distributed training or massively parallel hyperparameter tuning.
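To make the config-driven approach concrete: instead of fixed values, a search configuration describes a range for each hyperparameter, and the platform samples values and launches one trial per sample. The sketch below is a plain-Python random search over a config-style dictionary. Its `type`/`minval`/`maxval`/`vals` keys are loosely modeled on the shape of Determined's hyperparameter syntax, but the sampling code, the `sample_hparams` and `random_search` names, and the ranges are all hypothetical illustrations, not Determined's searcher.

```python
import random


def sample_hparams(space, rng):
    """Draw one hyperparameter assignment from a config-style search space.

    Supported (hypothetical) spec forms:
      - plain value: used as-is (a constant)
      - {"type": "double", "minval": a, "maxval": b}: uniform float in [a, b]
      - {"type": "log", "minval": a, "maxval": b, "base": k}: k ** uniform(a, b)
      - {"type": "categorical", "vals": [...]}: one of the listed values
    """
    sample = {}
    for name, spec in space.items():
        if not isinstance(spec, dict):
            sample[name] = spec
        elif spec["type"] == "double":
            sample[name] = rng.uniform(spec["minval"], spec["maxval"])
        elif spec["type"] == "log":
            sample[name] = spec.get("base", 10) ** rng.uniform(spec["minval"], spec["maxval"])
        elif spec["type"] == "categorical":
            sample[name] = rng.choice(spec["vals"])
        else:
            raise ValueError(f"unsupported type for {name!r}: {spec['type']!r}")
    return sample


def random_search(space, num_trials, seed=0):
    """Return one hyperparameter dict per trial; a platform would launch each as its own trial."""
    rng = random.Random(seed)
    return [sample_hparams(space, rng) for _ in range(num_trials)]


# Hypothetical search space over the three hyperparameters MNISTTrial reads, plus a constant.
space = {
    "batch_size": 128,
    "learning_rate": {"type": "log", "minval": -5, "maxval": -2, "base": 10},
    "dropout1": {"type": "double", "minval": 0.2, "maxval": 0.5},
    "dropout2": {"type": "categorical", "vals": [0.3, 0.5]},
}
for trial_hparams in random_search(space, num_trials=4):
    print(trial_hparams)
```

The model code never changes between these trials; only the dictionary each trial's `get_hparam` calls read from does.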
For essentially the same initial investment, you get distributed training, hyperparameter tuning, fault tolerance, better code readability, and built-in cluster management. Determined makes it that easy. Try it out now or check out Determined on GitHub!