October 04, 2023
We are pleased to announce the launch of new functionality within the Batch Processing API!
Previously, we launched the Batch Processing API for a text embedding use case, where users shard a dataset across multiple workers and apply processing logic to batches of data to create vector embeddings for documents. Today, this API enables our users to tackle another challenge on their ML journey: validating a trained model.
In this blog, we will introduce the new features and cover one specific use case in more depth: conducting model validation after training in a Jupyter Notebook environment.
There are a few reasons you might want to validate a model outside of training. For example, you might want to evaluate a pre-trained model against a different dataset, or use domain-specific metrics for a business use case. The typical train-validation loop generates standard validation metrics (like accuracy), but sometimes you need more than these to decide which model goes into deployment, and it rarely makes sense to evaluate a model on specialized metrics during a training run. Sometimes additional custom metrics are only chosen after models are already trained.
To address this, we added the ability to track custom metrics that are generated in an inference run and associate them with the models used to generate them.
In this example, we will run batch inference over a set of models trained on the MNIST dataset. We will walk through the process of writing an inference job that counts how many “9”s were predicted by our trained models.
We will start with a set of models saved into the Determined Model Registry. To learn more about how to save models into the Model Registry, please check the model registry documentation.
The full example can be found here: Inference Metrics MNIST example.
We will extend `TorchBatchProcessor` to calculate custom inference metrics in a distributed manner. As explained in Part 1, the Batch Processing API works like the Trial APIs, e.g. by extending the `TorchBatchProcessor` class and implementing its methods.
To start, we initialize the `InferenceProcessor`:
```python
class InferenceProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        hparams = self.context.get_hparams()
        model = client.get_model(hparams.get("model_name"))
        model_version = model.get_version(hparams.get("model_version"))
        self.context.report_task_using_model_version(model_version)
```
The key new addition is that metrics can now be associated with a particular model version. This is done with the call `context.report_task_using_model_version()`.
From then on, all metrics generated in the batch inference run will be associated with this model version for analysis later (covered in Section 2 below).
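Between initialization and `on_finish`, each worker processes its shard of batches and accumulates a local count. The real `process_batch` implementation depends on your model and dataloader; the sketch below only illustrates the counting logic, with plain lists standing in for model predictions (the `count_nines` helper and the accumulator are illustrative, not part of the API):

```python
# Hypothetical sketch of the per-worker counting logic. In the real
# InferenceProcessor, predictions would come from running the model on
# each batch inside process_batch; here plain lists stand in for them.

def count_nines(predicted_labels):
    """Count how many predictions are the digit 9."""
    return sum(1 for label in predicted_labels if label == 9)

# Each worker accumulates a local count across its shard of batches:
local_count = 0
for batch_predictions in [[9, 3, 9, 1], [0, 9, 2, 7]]:
    local_count += count_nines(batch_predictions)

print(local_count)  # 3
```

The per-worker count can then be included in the metrics dictionary reported in `on_finish`.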
In the `on_finish` function, the user can now report arbitrary groups of custom metrics in the form of a dictionary with the call `context.report_metrics()`. One example might look like:
```python
def on_finish(self):
    self.context.report_metrics(
        group="inference",
        steps_completed=self.rank,
        metrics=my_metrics,
    )
```
Once you’ve reported metrics as shown, you can access this model version or checkpoint later from a Determined environment like a Jupyter Notebook to aggregate and analyze its inference metrics. The generated metrics can then be processed by filtering, grouping, and sorting.
```python
from determined.experimental import client

model = client.get_model("<YOUR_MODEL_NAME_HERE>")
model_version = model.get_version(1)
metrics = model_version.get_metrics()  # Generator of all associated metrics
```
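As a hypothetical example of such processing, you might iterate over the metrics and sum the per-worker counts. The dictionary shape and the `num_predicted_nines` key below are assumptions for illustration; the exact structure of the entries yielded by `get_metrics()` may differ:

```python
# Hypothetical aggregation over reported inference metrics. Each entry
# here mirrors what report_metrics() recorded: a group name plus a
# metrics dict. The shapes and key names are illustrative assumptions.

def total_nines(metric_entries):
    """Sum 'num_predicted_nines' over entries in the 'inference' group."""
    return sum(
        entry["metrics"].get("num_predicted_nines", 0)
        for entry in metric_entries
        if entry.get("group") == "inference"
    )

# Stand-in entries, one per worker, plus an unrelated group to filter out:
entries = [
    {"group": "inference", "metrics": {"num_predicted_nines": 3}},
    {"group": "inference", "metrics": {"num_predicted_nines": 5}},
    {"group": "training", "metrics": {"loss": 0.1}},
]
print(total_nines(entries))  # 8
```

The same pattern extends to grouping by worker rank or sorting model versions by any custom metric.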
This new inference metrics functionality helps address pain points in use cases outside of training, such as model validation. Any existing models a user would like to validate can now be organized, and their results aggregated for analysis, with any custom metrics the user designs.