Use Batch Inference in Determined Part 2: Validate Trained Models with Metrics

We are pleased to announce the launch of new functionality within the Batch Processing API!

Previously, we launched the Batch Processing API for a text embedding computation use case, where users shard a dataset across multiple workers and apply processing logic to batches of data to create vector embeddings for various documents. Today, this API enables our users to tackle another challenge on their ML journey: validating a trained model.

In this blog, we will introduce the new features and cover one specific use case in more depth: conducting model validation after training in a Jupyter Notebook environment.


Why run additional validation on your model outside a training job?

There are a few reasons. For example, you might want to evaluate a pre-trained model against a different dataset, or use domain-specific metrics for a business use case. The typical train-validation loop generates standard validation metrics (like accuracy), but sometimes you need more than these to decide which model goes into deployment, and it doesn’t make sense to fold specialized evaluation into a training run. Or the additional custom metrics are only chosen after models have already been trained.

To address this, we added the ability to track custom metrics that are generated in an inference run and associate them with the models used to generate them.

💡 The key addition in this new set of functionality is that you can now associate trained models (in the form of a checkpoint or a model version) with the inference metrics generated from them.


Use Case: Validate a trained MNIST model on a custom evaluation metric

In this example, we will run batch inference over a set of models trained on the MNIST dataset. We will walk through writing an inference job that counts how many “9”s our trained models predicted. We will start with a set of models saved to the Determined Model Registry. To learn more about saving models to the Model Registry, please check the model registry documentation. The full example can be found here: Inference Metrics MNIST example.
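
For reference, registering a trained checkpoint as a model version with the Python SDK might look roughly like the following sketch (the experiment ID and model name are placeholders, not part of this example's code):

from determined.experimental import client

# Sketch: register the best checkpoint of a finished training experiment as a
# new version of a Model Registry entry. The experiment ID and model name below
# are placeholders for illustration.
exp = client.get_experiment(100)              # a finished MNIST training experiment
checkpoint = exp.top_checkpoint()             # best checkpoint by the searcher metric
model = client.create_model("mnist-example")  # or client.get_model(...) if it already exists
model.register_version(checkpoint.uuid)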


1) Reporting the metrics

We will extend TorchBatchProcessor to calculate custom inference metrics in a distributed manner. As explained in Part 1, the Batch Processing API works like the Trial APIs: you extend the TorchBatchProcessor class.


To start, we initialize the InferenceProcessor:

from determined.experimental import client
from determined.pytorch import experimental


class InferenceProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        hparams = self.context.get_hparams()
        model = client.get_model(hparams.get("model_name"))
        model_version = model.get_version(hparams.get("model_version"))
        self.context.report_task_using_model_version(model_version)


The important new addition here is that metrics can now be associated with a particular model version. This is done with the call context.report_task_using_model_version(model_version).

From then on, all metrics generated in the batch inference run will be associated with this model version for analysis later (covered in Section 2 below).
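
Between initialization and reporting, the per-batch work happens in the process_batch hook that TorchBatchProcessor calls for each batch of the worker's shard of the dataset. A minimal sketch of counting predicted “9”s might look like the following; it assumes __init__ also loaded the checkpoint behind the model version into self.model and initialized a running tally self.my_metrics, both of which are illustrative names rather than part of the documented API (see the full example for the complete code):

import torch

# Inside InferenceProcessor: a sketch of the per-batch hook. Assumes __init__
# also set up self.model (the loaded model) and
# self.my_metrics = {"num_9s_predicted": 0}; both names are illustrative.
def process_batch(self, batch, batch_idx) -> None:
    inputs, _labels = batch
    with torch.no_grad():
        logits = self.model(inputs)
        predictions = torch.argmax(logits, dim=1)
    # Tally how many "9"s this worker's shard produced; on_finish reports it.
    self.my_metrics["num_9s_predicted"] += int((predictions == 9).sum().item())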


In the on_finish function, the user can now report arbitrary groups of custom metrics in the form of a dictionary with the call context.report_metrics(). One example might look like:

def on_finish(self):
    # my_metrics: a dictionary of custom metric values computed during inference,
    # e.g. the self.my_metrics tally accumulated in process_batch above.
    self.context.report_metrics(
        group="inference",
        steps_completed=self.rank,  # this worker's rank (assumed to be set in __init__)
        metrics=my_metrics,
    )
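
Finally, the processor class and a dataset are handed to the batch-processing entrypoint to launch the job. Here is a rough sketch, assuming the torch_batch_process entrypoint introduced in Part 1 and the MNIST test split from torchvision (the batch size and checkpoint interval values are illustrative):

import torchvision
from torchvision import transforms
from determined.pytorch import experimental

if __name__ == "__main__":
    # Validate against the held-out MNIST test split.
    dataset = torchvision.datasets.MNIST(
        root="/tmp/data",
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
    experimental.torch_batch_process(
        InferenceProcessor,
        dataset,
        batch_size=64,
        checkpoint_interval=10,
    )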


2) Fetching the Metrics

Once you’ve reported the metrics as shown above, you can access this model version or checkpoint later from a Determined environment such as a Jupyter Notebook, and aggregate and analyze its inference metrics. The generated metrics can then be processed by filtering, grouping, and sorting.


For example:

from determined.experimental import client
 
model = client.get_model("<YOUR_MODEL_NAME_HERE>")
model_version = model.get_version(1)
 
metrics = model_version.get_metrics()  # Generator of all associated metrics
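
From there, one convenient way to filter, group, and sort the results is to collect them into a pandas DataFrame. The sketch below assumes pandas is available in the notebook and that each returned record exposes its metric group and values as attributes; the attribute and metric names used here are assumptions for illustration:

import pandas as pd

# Sketch: aggregate the custom inference metrics reported by each worker.
# The record attributes (`group`, `metrics`) and the "num_9s_predicted" metric
# name are assumptions for illustration.
rows = [m.metrics for m in metrics if m.group == "inference"]
df = pd.DataFrame(rows)

print(df["num_9s_predicted"].sum())  # total "9"s predicted across all workers
print(df.sort_values("num_9s_predicted", ascending=False).head())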


Summary

This new inference metrics functionality helps you address pain points in use cases outside of training, such as model validation. Any existing model that you would like to validate can now be associated with the custom metrics you design, and those metrics can be organized and aggregated for analysis.

This functionality is now ready for you to try. To get started, please visit our documentation and join our Slack Community! As always, we are here to support you and welcome your feedback.