TensorFlow Datasets: The Bad Parts

TLDR: TensorFlow’s tf.data API is a popular approach to loading data into deep learning models. Although tf.data has a lot of powerful features, it is built around sequential access to the underlying dataset. This design makes it difficult to efficiently shuffle large datasets, to shard data when doing distributed training, and to implement fault-tolerant training. We argue that random access should be a key consideration when building deep learning data APIs.

Update: We’ve released YogaDL, a library for deep learning data loading that addresses a lot of the concerns described above. Learn more in the YogaDL announcement blog post!

There are many pitfalls engineering teams can fall into when building an end-to-end enterprise deep learning platform. One of the most common problems involves data loading. Data loading during training is often overlooked, and it can have massive implications for throughput. Machine learning frameworks provide abstractions that attempt to make data loading straightforward, but peeking behind the curtain of these seemingly simple interfaces can reveal surprising problems. In this post, we’ll be taking you behind the scenes of a popular data loading API: TensorFlow Datasets.

We’ve peeked behind the curtain of TensorFlow Datasets to reveal some glaring problems

Data Loader Patterns

There are two fundamental patterns that a data loading interface can use: random access and sequential access.

Random Access

Random access is the ability to access any element of a dataset efficiently. In Python, random access is often done by indexing into a list (e.g., data[index]), which calls __getitem__() behind the scenes. PyTorch uses this approach to define the map-style dataset interface (implemented below). Random-access data loader interfaces may also require that the user specify the length of the dataset (__len__()).

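from typing import Any, List

import torch.utils.data

class RandomAccessDataset(torch.utils.data.Dataset):
    def __init__(self, data: List) -> None:
        self.data = data

    def __len__(self) -> int:
        # Total number of elements, so a sampler can choose any valid index.
        return len(self.data)

    def __getitem__(self, index: int) -> Any:
        # Fetch a single element directly by its position.
        return self.data[index]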

Deep learning data APIs that support random access include tf.keras.utils.Sequence and torch.utils.data.Dataset (Map Style).

Sequential Access

Sequential access is a paradigm where elements must be accessed in a predetermined order, typically through an iterator. In Python, sequential access is often implemented via iterators and the yield expression.

from typing import Iterator, List

def sequential_dataset(data: List) -> Iterator:
    for item in data:
        yield item

Some deep learning frameworks, such as early versions of Keras, natively support expressing your data input pipeline as a Python generator. Similarly, TensorFlow Datasets are built around sequential data access. Converting a Python generator into a TensorFlow Dataset is straightforward, if a little verbose:

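import itertools

import tensorflow as tf

def gen():
    # Endless stream of (i, [1, 1, ..., 1]) pairs, where the list has length i.
    for i in itertools.count(1):
        yield (i, [1] * i)

dataset = tf.data.Dataset.from_generator(
    gen,
    (tf.int64, tf.int64),
    (tf.TensorShape([]), tf.TensorShape([None])))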

In the next section, we will discuss the drawbacks of building a data loader around sequential access.

Sequential Access in TensorFlow Datasets

TensorFlow’s tf.data API makes it easy to structure your data loading code in an elegant way: you can chain together a stream of operations on a dataset using lazy initialization, and tf.data provides helper APIs to do common tasks like prefetching and parallel data loading.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

for element in dataset:
    print(element)
# tf.Tensor(1, shape=(), dtype=int32)
# tf.Tensor(2, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32)

# map() is applied lazily, as elements are pulled through the pipeline.
dataset = dataset.map(lambda x: x * 2)
for element in dataset:
    print(element)
# tf.Tensor(2, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)
# tf.Tensor(6, shape=(), dtype=int32)

However, TensorFlow Datasets are fundamentally built around sequential access: each operator in a tf.data pipeline iterates over its input and produces a sequential output stream that is consumed by the next operator. The API does not support random access, which leads to major issues when trying to implement some common machine learning workflows.

dataset[0]
# TypeError: 'TensorSliceDataset' object does not support indexing

>>> [1,2,3]
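To make the cost concrete, here is a minimal pure-Python sketch (not TensorFlow code) of what emulating data[index] on top of a sequential stream involves. The helpers counting_stream and nth_element are hypothetical names introduced only for this illustration; the point is that reaching the element at a given index means consuming everything before it.

```python
import itertools
from typing import Any, Iterable, Iterator, List


def counting_stream(data: List[Any], reads: List[int]) -> Iterator[Any]:
    """Yield items in order, recording in reads[0] how many were produced."""
    for item in data:
        reads[0] += 1
        yield item


def nth_element(stream: Iterable[Any], index: int) -> Any:
    """Return the element at position `index` by skipping everything before it."""
    return next(itertools.islice(stream, index, None))


data = list(range(1000))
reads = [0]
value = nth_element(counting_stream(data, reads), 500)
print(value, reads[0])  # 500 501
```

Fetching one element from the middle of the stream touches 501 elements, whereas data[500] on the underlying list touches one.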

Data Shuffling

When training a deep learning model, the training set is often shuffled before being fed into the model — this typically improves generalization performance. If our data API only supports sequential access, how can we implement random shuffling? A simple but inefficient approach would be to read as much data as we can into memory and shuffle it there. In fact, that’s exactly what tf.data’s shuffle() method does!

This dataset fills a buffer with buffer_size elements, then randomly samples elements from this buffer, replacing the selected elements with new elements. For perfect shuffling, a buffer size greater than or equal to the full size of the dataset is required.
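The buffer behavior in this description can be modeled in a few lines of plain Python. This is a simplified sketch written for this post, not TensorFlow's actual implementation, and buffered_shuffle is a hypothetical name. It shows why a small buffer only shuffles locally: an element can never be emitted more than buffer_size positions earlier than where it started.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Shuffle a stream using only a fixed-size in-memory buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        slot = rng.randrange(buffer_size)
        yield buffer[slot]
        buffer[slot] = item
    # Input exhausted: emit whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


buffer_size = 10
output = list(buffered_shuffle(range(1000), buffer_size))

# The k-th output can only come from the first k + buffer_size inputs.
locally_bounded = all(value < k + buffer_size for k, value in enumerate(output))
print(locally_bounded)  # True
```

With a buffer of 10 over 1,000 sorted elements, the first output is always one of the first 10 inputs, so data that is ordered on disk (for example, by class label) stays largely ordered.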

For datasets that don’t fit entirely into memory (the most common case in deep learning), shuffle() doesn’t actually shuffle the full dataset! This means that shuffle() doesn’t have the intended effect in most applications. Many practitioners, including us, have made this error and seen their model’s generalization performance suffer as a result. Although it is possible to shuffle your entire dataset ahead of time by loading the data into memory or shuffling a list of filenames, many users might not realize this problem exists in their code!
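For comparison, here is a rough sketch of how a full shuffle looks when the dataset supports random access: only a list of integer indices has to fit in memory, not the data itself. The function name shuffled_epoch and the in-memory records list are assumptions made for the example; any object with __len__() and __getitem__() would work the same way.

```python
import random
from typing import Any, Iterator, Sequence


def shuffled_epoch(dataset: Sequence[Any], seed: int) -> Iterator[Any]:
    """Yield every element exactly once, in a uniformly random order."""
    order = list(range(len(dataset)))
    random.Random(seed).shuffle(order)  # shuffle cheap integer indices, not the data
    for index in order:
        yield dataset[index]


records = [f"record-{i}" for i in range(10)]
epoch = list(shuffled_epoch(records, seed=0))
```

Using a different seed per epoch gives a fresh permutation each time, and reusing a seed reproduces the same order.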

Data Sharding

When doing data-parallel distributed training, each worker (typically a GPU) is trained on a fraction (or “shard”) of the data in each batch. To handle this common task, tf.data provides a method that seems like a perfect fit: shard(n, i) splits a dataset into n shards and returns the i’th shard for further processing in the current worker.

Unfortunately, there’s a catch: shard() iterates over the entire input dataset, returning every n’th record and ignoring the rest! That means that if you apply shard() to a large dataset during distributed training, each worker in the distributed training job will end up reading the entire dataset. If you’re training a model with 64 GPUs, that means you’ll be doing 64x more disk I/O than you probably intended. It gets even worse if you’re doing on-the-fly data augmentation before the shard() operator in your pipeline — those data augmentation operations will be done redundantly by every worker.
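The difference in I/O can be illustrated with a small pure-Python model. This is a sketch under stated assumptions, not tf.data code: CountingSource, sequential_shard, and random_access_shard are hypothetical names, and a "read" here stands in for fetching one record from disk.

```python
from typing import Any, Iterator, List


class CountingSource:
    """A list-backed source that counts how many records are actually read."""

    def __init__(self, records: List[Any]) -> None:
        self.records = records
        self.reads = 0

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, index: int) -> Any:
        self.reads += 1
        return self.records[index]

    def __iter__(self) -> Iterator[Any]:
        for index in range(len(self.records)):
            yield self[index]


def sequential_shard(source: CountingSource, num_shards: int, index: int) -> Iterator[Any]:
    """Scan the whole stream and keep every num_shards'th record."""
    for position, record in enumerate(source):
        if position % num_shards == index:
            yield record


def random_access_shard(source: CountingSource, num_shards: int, index: int) -> Iterator[Any]:
    """Read only the records that belong to this shard."""
    for position in range(index, len(source), num_shards):
        yield source[position]


seq_source = CountingSource(list(range(6400)))
seq_result = list(sequential_shard(seq_source, 64, 3))

ra_source = CountingSource(list(range(6400)))
ra_result = list(random_access_shard(ra_source, 64, 3))

print(seq_source.reads, ra_source.reads)  # 6400 100
```

Both versions return the same 100 records for worker 3 of 64, but the sequential version reads all 6,400 records to find them.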

The TensorFlow documentation acknowledges this and observes:

Generally it is best if the shard operator is used early in the dataset pipeline.

TensorFlow’s recommended approach is to create a dataset of TFRecord file names and apply shard() to this list. Each worker receives a disjoint set of files to process, which avoids any unnecessary disk I/O. This approach works, but it has two problems:

  1. You need to split your dataset into a larger number of files than the number of workers in your distributed training job. If you have a large dataset stored in a small number of files, you’re out of luck. Moreover, any size imbalances between those files will result in stragglers, hurting training performance.
  2. More likely, you might not realize any of this! A lot of real-world data loading code just converts a Python generator into a TensorFlow Dataset using Dataset.from_generator(). This will appear to work okay at small scale, but will quickly run into serious performance problems as your dataset grows.
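The first problem in the list above can be sketched in plain Python. The file names and record counts below are invented for illustration, and files_for_worker is a hypothetical helper that uses the same every-n'th-element rule as shard(), applied to a list of file names instead of to records.

```python
from typing import Dict, List


def files_for_worker(files: List[str], num_workers: int, worker_index: int) -> List[str]:
    """Give each worker every num_workers'th file, so the sets are disjoint."""
    return files[worker_index::num_workers]


# Hypothetical file names and record counts, invented for this example.
record_counts: Dict[str, int] = {
    "part-0.tfrecord": 1000,
    "part-1.tfrecord": 1000,
    "part-2.tfrecord": 1000,
    "part-3.tfrecord": 5000,
    "part-4.tfrecord": 1000,
}
files = sorted(record_counts)
num_workers = 4

# Total number of records each worker has to process.
load = [
    sum(record_counts[name] for name in files_for_worker(files, num_workers, w))
    for w in range(num_workers)
]
print(load)  # [2000, 1000, 1000, 5000]

# With fewer files than workers, some workers receive nothing at all.
idle = files_for_worker(files[:2], num_workers, 3)
print(idle)  # []
```

In this made-up case the last worker processes five times as many records as two of the others, so the rest of the job waits on it at every synchronization point.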

Saving and Restoring Iterator State

If you want to build a deep learning training system that can recover from faults, a common approach is to use checkpoint-restart: periodically save the state of the job to a checkpoint file, and, when a failure occurs, restore the job from the most recent checkpoint. This is particularly important for training jobs that can last hours or even days. However, saving and restoring a training job requires knowing the job’s position in the dataset.

For example, if you’re training on a 100,000-element dataset and the most recent checkpoint was taken after the 50,000th record, you’ll want to ensure that training resumes at the 50,001st record instead of restarting from the beginning. With a random-access interface, this is trivial: just save the index position as an integer and pick up where you left off when restoring. With a sequential-access interface, this can be really hard. Python generators are notoriously hard to pickle. TensorFlow datasets have experimental support for checkpointing and restoring some types of datasets, but not those created with tf.data.Dataset.from_generator().
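As a minimal sketch of the random-access case, the example below saves and restores a loader's position as a single integer. ResumableLoader is a hypothetical class written for this post, and a plain Python list stands in for the dataset. The last few lines also show that the standard pickle module refuses to serialize a running generator.

```python
import pickle
from typing import Any, Dict, Iterator, Sequence


class ResumableLoader:
    """Iterates over a random-access dataset and remembers its position."""

    def __init__(self, dataset: Sequence[Any], position: int = 0) -> None:
        self.dataset = dataset
        self.position = position

    def __iter__(self) -> Iterator[Any]:
        while self.position < len(self.dataset):
            item = self.dataset[self.position]
            self.position += 1
            yield item

    def state(self) -> Dict[str, int]:
        return {"position": self.position}


dataset = list(range(100_000))
loader = ResumableLoader(dataset)
stream = iter(loader)
for _ in range(50_000):
    next(stream)  # consume the first 50,000 records

checkpoint = pickle.dumps(loader.state())  # the whole checkpoint is one integer

restored = ResumableLoader(dataset, **pickle.loads(checkpoint))
first_after_restore = next(iter(restored))
print(first_after_restore)  # 50000, i.e. the 50,001st record

# By contrast, the generator object itself cannot be pickled.
try:
    pickle.dumps(stream)
    generator_pickled = True
except TypeError:
    generator_pickled = False
print(generator_pickled)  # False
```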


All of the above problems exist because tf.data is built around sequential access. So do yourself a favor: don’t restrict your entire data loading codebase to sequential access patterns. If you would like to apply prefetching and a functional style with a sequential access pattern, you can always wrap a random access interface as follows:

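from typing import Iterator

# RandomAccessDataset is the map-style class defined earlier in this post.
dataset = RandomAccessDataset([1, 2, 3])

def sequential_access_dataset() -> Iterator:
    for index in range(len(dataset)):
        yield dataset[index]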

It’s easy to go from random to sequential! But going the other way is much harder.

TensorFlow Datasets are currently the recommended way to load data in TensorFlow, and it doesn’t look like that is going to change any time soon. Many readers of this article might find themselves in the unfortunate position of being stuck with TensorFlow Datasets due to forces beyond their control. If this is you, we’ve been hard at work on a solution that will make your life easier—stay tuned for more next week!