Using Huge, Heterogeneous Datasets in TensorFlow
27 Feb 2022

When using TensorFlow, a dataset can sometimes be so large that it cannot be stored in main memory in its entirety.
TensorFlow provides the `tf.data.Dataset` API to reduce the memory footprint and improve efficiency when working with large datasets.
However, the examples in the documentation are built around common data types such as text and images.
It is unclear how to adapt these approaches to other kinds of huge custom datasets.
In this post, I discuss a method that I developed for huge datasets containing heterogeneous data types (based on https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset).
This technique allows one to store the huge dataset on disk, which reduces memory consumption.
It also allows one to use the dataset APIs for efficient dataset transformations and multi-GPU training.
(Examples are tested with Python 3.8 and TensorFlow 2.8.0.)
A huge, heterogeneous dataset
A heterogeneous dataset is defined as follows:
- Each sample in the dataset is a key-value pair (a `dict` object in Python)
- The keys in each sample are the same across the entire dataset
- For each key, the corresponding value is an array (or a scalar, which is an array with size 1)
- The data type of the arrays is the same across samples for a given key
- For a given key, the arrays should have the same number of dimensions
- For a given key, the size of the last array dimension should be the same
Despite these six conditions, I think most datasets can be represented in this way. An example of a sample from such a dataset is:
{
    'thumbnail': np.zeros((100, 100, 3), np.uint8),
    'waveform': np.zeros((100000, 2), np.int16),
    'spectrum': np.zeros((22050, 2), np.float32),
    'label': 0
}
This can be thought of as information extracted from a music audio file:
- The `thumbnail` is a \(100\times 100\) RGB image
- The `waveform` contains the time-domain samples extracted from the audio file
- The `spectrum` is computed from the frequency-domain representation of the audio signal
- The `label` is the classification target of this sample
The dataset is “heterogeneous” because it contains many fields of different data types.
To proceed, one should preprocess the dataset so that the data samples can be represented as a list of such `dict` objects.
Note that it is not required to transform all data samples at once; the `dict` objects can also be obtained from a generator, as sketched below.
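For example, a generator that lazily produces one sample `dict` at a time (so that the whole dataset never has to reside in memory) could look like the following sketch; the dummy arrays and file names are placeholders standing in for real feature extraction:

import numpy as np

# a sketch of a sample generator; replace the dummy arrays with real feature extraction
def sample_generator(audio_files):
    for path in audio_files:
        # in practice, compute thumbnail/waveform/spectrum/label from the file at `path`
        yield {
            'thumbnail': np.zeros((100, 100, 3), np.uint8),
            'waveform': np.zeros((100000, 2), np.int16),
            'spectrum': np.zeros((22050, 2), np.float32),
            'label': 0,
        }

samples = sample_generator(['song_01.wav', 'song_02.wav'])  # placeholder file names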
Converting the dataset to TensorFlow format
In the following example, we serialize the data sample `dict` objects into a TFRecord file that can later be loaded as a `TFRecordDataset`.
- The data sample `dict` objects are fetched from the `samples` variable, which can be a list or a generator
- The dataset is written to `dataset.tfrecords`
- The dataset metadata is stored in `dataset_metadata.json`
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # consider suppressing the use of GPU
import tensorflow as tf
import numpy as np
import json

# wrapper for binary features (in TFRecordDataset)
def bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# wrapper for int64 features (in TFRecordDataset)
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# the input sample is a `dict` object
def transform_sample(sample):
    # this dictionary is used to store the data type and shape information of each key
    key_info = dict()
    for key, val in sample.items():
        if isinstance(val, np.ndarray):
            # generic handler for np.ndarray
            key_info[key] = dict(dtype=repr(val.dtype), shape=val.shape)
        elif isinstance(val, int):
            # special handler for int64 type
            key_info[key] = '@@int'
        else:
            raise RuntimeError('unknown data type')
    # this dictionary is the one containing the data sample
    tf_sample = dict()
    for key, val in sample.items():
        if isinstance(val, np.ndarray):
            # generic handler for np.ndarray
            tensor_val = tf.convert_to_tensor(val)
            tensor_bin = tf.io.serialize_tensor(tensor_val)
            tf_sample[key] = bytes_feature(tensor_bin)
        elif isinstance(val, int):
            # special handler for int64 type
            tf_sample[key] = int64_feature(val)
        else:
            raise RuntimeError('unknown data type')
    return tf.train.Example(features=tf.train.Features(feature=tf_sample)), key_info
samples = []  # a list of samples, or a generator of samples
key_info = None  # used to keep track of the data type and shape information of each key
                 # (only need the last one since it is the same across samples)
num_samples = 0  # keep track of the number of samples

# the samples are serialized and written to this file
with tf.io.TFRecordWriter('dataset.tfrecords') as writer:
    # the dataset building loop
    for sample in samples:
        try:
            tf_sample, key_info = transform_sample(sample)
            writer.write(tf_sample.SerializeToString())
            num_samples += 1
        except Exception as e:
            print(f'an error occurred: {repr(e)}')

# build and save dataset metadata
dataset_metadata = {
    'num_samples': num_samples,
    'key_info': key_info
}
with open('dataset_metadata.json', 'w') as outfile:
    json.dump(dataset_metadata, outfile, indent=2)
Additional comments:
- One can run separate dataset building loops to store the training set and the test set in different files.
- This process can also be parallelized. More concretely, one can run multiple copies of the dataset building loop and store the results in different files. These distributed files can be merged into one dataset later, as sketched below.
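For example, each worker process can write its own shard file with the same dataset building loop, and the shards are "merged" simply by listing all of the file names when the dataset is loaded back. This is a minimal sketch; the shard naming scheme, the `worker_id` variable, the `worker_samples` slice, and the worker count of 4 are placeholders I introduce here:

# writer side: run one copy per worker, each with its own worker_id and slice of the samples
shard_path = f'dataset_shard_{worker_id}.tfrecords'  # worker_id is a placeholder, e.g. 0..3
with tf.io.TFRecordWriter(shard_path) as writer:
    for sample in worker_samples:  # worker_samples is this worker's slice of the full dataset
        tf_sample, key_info = transform_sample(sample)
        writer.write(tf_sample.SerializeToString())

# reader side: merging amounts to listing every shard file when the dataset is loaded back
shard_files = [f'dataset_shard_{i}.tfrecords' for i in range(4)]  # 4 workers assumed
dataset = tf.data.TFRecordDataset(filenames=shard_files)  # the rest of the pipeline is unchanged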
Using the dataset
import os
import tensorflow as tf
import numpy as np
import json
import re

# load back dataset metadata
with open('dataset_metadata.json', 'r') as infile:
    dataset_metadata = json.load(infile)

# compute number of training steps
batch_size = 4
num_train_steps = int(np.ceil(dataset_metadata['num_samples'] / batch_size))

# generate a dataset description for deserialization
dataset_description = dict()
for key, val in dataset_metadata['key_info'].items():
    if val == '@@int':
        dataset_description[key] = tf.io.FixedLenFeature([], tf.int64)
    else:
        dataset_description[key] = tf.io.FixedLenFeature([], tf.string)

# define the first transformation to deserialize the dataset
def parse_dataset(raw_example):
    parsed = tf.io.parse_single_example(raw_example, dataset_description)
    for key, val in dataset_metadata['key_info'].items():
        # this step is only needed for non-integer type
        if val != '@@int':
            s_res = re.search("'(.*?)'", val['dtype'])  # search for the data type
            type_str = s_res.group(1)
            parsed[key] = tf.io.parse_tensor(parsed[key], getattr(tf, type_str))
    return parsed
# define the second transformation to split a sample into the sample and its label
def export_sample_and_label(x):
    return x, x['label']

# define the third transformation to specify the shapes of the keys
# this is mandatory, otherwise errors may occur in training
# see https://github.com/tensorflow/tensorflow/issues/32912#issuecomment-550363802
def fix_shape(x, y):
    for key, val in dataset_metadata['key_info'].items():
        if val != '@@int':
            x[key].set_shape(val['shape'])
    return x, y

# load the dataset back
# if there are multiple files (e.g., generated by distributed dataset generation),
# specify them in the `filenames` list
dataset = tf.data.TFRecordDataset(filenames=['dataset.tfrecords']).map(parse_dataset)\
                                                                  .map(export_sample_and_label)\
                                                                  .map(fix_shape)
train_ds = dataset.batch(batch_size).prefetch(batch_size).repeat()
'''
what `train_ds` looks like:
<RepeatDataset element_spec=({'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None),
'spectrum': TensorSpec(shape=(None, None, 2), dtype=tf.float32, name=None),
'thumbnail': TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8, name=None),
'waveform': TensorSpec(shape=(None, None, 2), dtype=tf.int16, name=None)},
TensorSpec(shape=(None,), dtype=tf.int64, name=None))>
'''
# define a custom model
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.proj_a = tf.keras.layers.Dense(10)
        self.proj_b = tf.keras.layers.Dense(10)
        self.proj_c = tf.keras.layers.Dense(10)
        self.last_layer = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, x, training=False):
        # do some dummy operations on each field
        a = tf.cast(x['thumbnail'], tf.float32)
        b = tf.cast(x['waveform'], tf.float32)
        c = x['spectrum']
        value = [
            self.proj_a(a[:, :, 0, 0]),
            self.proj_b(b[:, :, 0]),
            self.proj_c(c[:, :, 0])
        ]
        value = tf.concat(value, axis=1)
        return self.last_layer(value)

# initialize the dummy model and start training
model = MyModel()
model.compile(loss='binary_crossentropy', optimizer='Adam')
model.fit(train_ds, epochs=2, steps_per_epoch=num_train_steps)
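Because `train_ds` is a regular `tf.data` pipeline, the same setup can also be used for the multi-GPU training mentioned in the introduction. Here is a minimal sketch using `tf.distribute.MirroredStrategy`; it is an illustration I add here, not part of the original single-device example:

# a minimal multi-GPU sketch using MirroredStrategy (illustrative; not in the original example)
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = MyModel()  # the model and optimizer must be created inside the strategy scope
    model.compile(loss='binary_crossentropy', optimizer='Adam')
model.fit(train_ds, epochs=2, steps_per_epoch=num_train_steps)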
Additional comments:
- In this `fix_shape` implementation, the array shape for a given key is assumed to be the same across all samples. It is possible to support variable-length inputs by setting the corresponding dimension to `None`, as sketched below.
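For example, if the length of the `waveform` array varies from sample to sample, the fixed shape recorded in the metadata can be relaxed at load time and each batch can be padded to a common length. This is a minimal sketch; the assumption that only the first dimension of `waveform` varies is mine, not part of the original post:

# relax the first dimension of `waveform` (assumed to be the variable one) to None
def fix_shape_variable(x, y):
    for key, val in dataset_metadata['key_info'].items():
        if val != '@@int':
            shape = list(val['shape'])
            if key == 'waveform':
                shape[0] = None  # leave the variable dimension unspecified
            x[key].set_shape(shape)
    return x, y

dataset = tf.data.TFRecordDataset(filenames=['dataset.tfrecords']).map(parse_dataset)\
                                                                  .map(export_sample_and_label)\
                                                                  .map(fix_shape_variable)
# padded_batch pads the variable dimension to the longest element in each batch
train_ds = dataset.padded_batch(batch_size).prefetch(batch_size).repeat()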
Related links
- https://www.tensorflow.org/guide/data
- https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset
Complete dummy example Python files
Dataset generation
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # consider suppressing the use of GPU
import tensorflow as tf
import numpy as np
import json

# wrapper for binary features (in TFRecordDataset)
def bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# wrapper for int64 features (in TFRecordDataset)
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# the input sample is a `dict` object
def transform_sample(sample):
    # this dictionary is used to store the data type and shape information of each key
    key_info = dict()
    for key, val in sample.items():
        if isinstance(val, np.ndarray):
            # generic handler for np.ndarray
            key_info[key] = dict(dtype=repr(val.dtype), shape=val.shape)
        elif isinstance(val, int):
            # special handler for int64 type
            key_info[key] = '@@int'
        else:
            raise RuntimeError('unknown data type')
    # this dictionary is the one containing the data sample
    tf_sample = dict()
    for key, val in sample.items():
        if isinstance(val, np.ndarray):
            # generic handler for np.ndarray
            tensor_val = tf.convert_to_tensor(val)
            tensor_bin = tf.io.serialize_tensor(tensor_val)
            tf_sample[key] = bytes_feature(tensor_bin)
        elif isinstance(val, int):
            # special handler for int64 type
            tf_sample[key] = int64_feature(val)
        else:
            raise RuntimeError('unknown data type')
    return tf.train.Example(features=tf.train.Features(feature=tf_sample)), key_info
# create some dummy samples
samples = []
for _ in range(8):
    samples.append({
        'thumbnail': np.zeros((100, 100, 3), np.uint8),
        'waveform': np.zeros((100000, 2), np.int16),
        'spectrum': np.zeros((22050, 2), np.float32),
        'label': 0
    })

key_info = None  # used to keep track of the data type and shape information of each key
                 # (only need the last one since it is the same across samples)
num_samples = 0  # keep track of the number of samples

# the samples are serialized and written to this file
with tf.io.TFRecordWriter('dataset.tfrecords') as writer:
    # the dataset building loop
    for sample in samples:
        try:
            tf_sample, key_info = transform_sample(sample)
            writer.write(tf_sample.SerializeToString())
            num_samples += 1
        except Exception as e:
            print(f'an error occurred: {repr(e)}')

# build and save dataset metadata
dataset_metadata = {
    'num_samples': num_samples,
    'key_info': key_info
}
with open('dataset_metadata.json', 'w') as outfile:
    json.dump(dataset_metadata, outfile, indent=2)
Dataset testing
import os
import tensorflow as tf
import numpy as np
import json
import re

# load back dataset metadata
with open('dataset_metadata.json', 'r') as infile:
    dataset_metadata = json.load(infile)

# compute number of training steps
batch_size = 4
num_train_steps = int(np.ceil(dataset_metadata['num_samples'] / batch_size))

# generate a dataset description for deserialization
dataset_description = dict()
for key, val in dataset_metadata['key_info'].items():
    if val == '@@int':
        dataset_description[key] = tf.io.FixedLenFeature([], tf.int64)
    else:
        dataset_description[key] = tf.io.FixedLenFeature([], tf.string)

# define the first transformation to deserialize the dataset
def parse_dataset(raw_example):
    parsed = tf.io.parse_single_example(raw_example, dataset_description)
    for key, val in dataset_metadata['key_info'].items():
        # this step is only needed for non-integer type
        if val != '@@int':
            s_res = re.search("'(.*?)'", val['dtype'])  # search for the data type
            type_str = s_res.group(1)
            parsed[key] = tf.io.parse_tensor(parsed[key], getattr(tf, type_str))
    return parsed
# define the second transformation to split a sample into the sample and its label
def export_sample_and_label(x):
    return x, x['label']

# define the third transformation to specify the shapes of the keys
# this is mandatory, otherwise errors may occur in training
# see https://github.com/tensorflow/tensorflow/issues/32912#issuecomment-550363802
def fix_shape(x, y):
    for key, val in dataset_metadata['key_info'].items():
        if val != '@@int':
            x[key].set_shape(val['shape'])
    return x, y

# load the dataset back
# if there are multiple files (e.g., generated by distributed dataset generation),
# specify them in the `filenames` list
dataset = tf.data.TFRecordDataset(filenames=['dataset.tfrecords']).map(parse_dataset)\
                                                                  .map(export_sample_and_label)\
                                                                  .map(fix_shape)
train_ds = dataset.batch(batch_size).prefetch(batch_size).repeat()
# define a custom model
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.proj_a = tf.keras.layers.Dense(10)
        self.proj_b = tf.keras.layers.Dense(10)
        self.proj_c = tf.keras.layers.Dense(10)
        self.last_layer = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, x, training=False):
        # do some dummy operations on each field
        a = tf.cast(x['thumbnail'], tf.float32)
        b = tf.cast(x['waveform'], tf.float32)
        c = x['spectrum']
        value = [
            self.proj_a(a[:, :, 0, 0]),
            self.proj_b(b[:, :, 0]),
            self.proj_c(c[:, :, 0])
        ]
        value = tf.concat(value, axis=1)
        return self.last_layer(value)

# initialize the dummy model and start training
model = MyModel()
model.compile(loss='binary_crossentropy', optimizer='Adam')
model.fit(train_ds, epochs=2, steps_per_epoch=num_train_steps)