# Data loading
As of TF 2.3, the officially recommended way is to use TFRecord, which stores pre-processed tensors and reads them back later in the data pipeline (a minimal writing sketch follows the list below).
- saves the time of generating tensors on the fly
- aggregates quite a few images into a single file
- JSON-like serialization mechanism, protocol buffers
- streaming format, so no random access
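As a minimal sketch of how such a file is produced, each record is a serialized tf.train.Example protocol buffer. The feature names "image"/"label", the dummy byte strings, and the output path here are made up for illustration:

```python
import tensorflow as tf

# Hypothetical feature names and file path, just for illustration
def serialize_example(image_bytes, label):
    feature = {
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    # tf.train.Example is the protocol-buffer container mentioned above
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

with tf.io.TFRecordWriter("images.tfrecord") as writer:
    for image_bytes, label in [(b"\x00" * 8, 0), (b"\xff" * 8, 1)]:  # dummy data
        writer.write(serialize_example(image_bytes, label))
```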
Inspecting a TFRecord file:
import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset("path-to-file")
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
    # or convert the parsed proto to a Python dict:
    # import json
    # from google.protobuf.json_format import MessageToJson
    # json_string = MessageToJson(example)
    # json.loads(json_string)
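To read the records back inside a pipeline, each serialized Example is parsed with a feature spec and mapped into tensors. A rough sketch, assuming the hypothetical "image"/"label" features and file name from the writing sketch above:

```python
# Feature names/types are assumptions matching the writing sketch above
feature_spec = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}

def parse(raw_record):
    parsed = tf.io.parse_single_example(raw_record, feature_spec)
    image = tf.io.decode_raw(parsed["image"], tf.uint8)  # raw bytes back to a tensor
    return image, parsed["label"]

dataset = tf.data.TFRecordDataset("images.tfrecord").map(parse).batch(2)
```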
There are other ways:
tf.keras.utils.Sequence
: follows the same random-access loading philosophy as PyTorch's Dataset.

class Dataset(tf.keras.utils.Sequence):
    def __init__(self):
        pass

    def __getitem__(self, idx):
        pass

    def __len__(self):
        pass
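A filled-in sketch of the skeleton above, assuming dummy in-memory numpy arrays; Keras can consume such a Sequence directly via model.fit:

```python
import math
import numpy as np
import tensorflow as tf

# Sketch: random-access batches over in-memory numpy arrays (dummy data)
class NumpySequence(tf.keras.utils.Sequence):
    def __init__(self, x, y, batch_size=32):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[s], self.y[s]

seq = NumpySequence(np.random.rand(100, 4), np.random.randint(2, size=100), batch_size=16)
# model.fit(seq, epochs=3)  # Keras consumes a Sequence directly
```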
tf.data.Dataset.from_generator
: either write a generator function yourself or convert a Sequence into a generator. output_types and output_shapes are required. A minimal plain-generator example follows the conversion code below.

# Converts a Sequence into a generator-backed dataset
def seq_to_dataset(self, seq):
    def generator():
        multi_enqueuer = tf.keras.utils.OrderedEnqueuer(seq)
        multi_enqueuer.start()
        while True:
            batch_xs, batch_ys = next(multi_enqueuer.get())
            batch_ys = tf.constant(batch_ys, dtype=tf.float32)
            yield batch_xs, batch_ys

    # if the structure is a sequence, both items must be tuples instead of lists
    x_types = ((tf.int32,) * 10 + (tf.float32,)) * 2
    x_shapes = (tf.TensorShape([None, None]),) * 22
    # some_internal_fn(input_type, shallow_type): shallow_type comes from output_types below;
    # since shallow_type (output_types) only supports tuples, input_type must be a tuple as well,
    # i.e. the Sequence object must yield tuples instead of lists; the structure can be nested
    dataset = tf.data.Dataset.from_generator(
        generator,
        output_types=(x_types, tf.float32),
        output_shapes=(x_shapes, tf.TensorShape([None])))
    return dataset
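The plain-generator option is simpler; a minimal sketch with a toy generator (the shapes and dtypes here are made up):

```python
import tensorflow as tf

# A plain generator function; output_types/output_shapes must still be given
def gen():
    for i in range(5):
        yield [i] * (i + 1), float(i)  # variable-length x, scalar y

ds = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.int32, tf.float32),
    output_shapes=(tf.TensorShape([None]), tf.TensorShape([])))

for x, y in ds:
    print(x.numpy(), y.numpy())
```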
tf.data.Dataset
: from_x is like PyTorch's __init__, map is like __getitem__, but every operation is based on tensors. For non-tensor operations, use tf.py_function. A short pipeline sketch follows.
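A short sketch of that analogy, assuming two placeholder image paths and hypothetical preprocessing (decode + resize):

```python
import tensorflow as tf

paths = tf.constant(["a.png", "b.png"])   # placeholder file names
labels = tf.constant([0, 1])

# from_tensor_slices plays the role of __init__ (register the samples) ...
ds = tf.data.Dataset.from_tensor_slices((paths, labels))

# ... and map plays the role of __getitem__ (turn one sample into tensors)
def load(path, label):
    image = tf.io.decode_png(tf.io.read_file(path), channels=3)
    return tf.image.resize(image, [224, 224]), label

ds = ds.map(load, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(2)
```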
tf.data.Dataset.list_files vs tf.data.Dataset.from_tensor_slices
# as of tf2.1
# "The file_pattern argument should be a small number of glob patterns. If your filenames
# have already been globbed, use Dataset.from_tensor_slices(filenames) instead, as
# re-globbing every filename with list_files may result in poor performance with remote
# storage systems."
list_files(
    file_pattern, shuffle=None, seed=None
)
from_tensor_slices(
    tensors
)
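A usage sketch of the two options, assuming a placeholder data/*.png pattern:

```python
import glob
import tensorflow as tf

# Option 1: let tf.data do the globbing (a small number of patterns)
ds1 = tf.data.Dataset.list_files("data/*.png", shuffle=True, seed=42)

# Option 2: glob up front yourself, then slice the resulting list
filenames = sorted(glob.glob("data/*.png"))
ds2 = tf.data.Dataset.from_tensor_slices(filenames)
```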
tf.py_function vs tf.function
tf.py_function enables non-TF operations in graph mode; however, the result cannot be serialized and ported. tf.function, on the other hand, is a utility that converts operations into graph mode, and it requires TF operations.
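A minimal sketch contrasting the two (the function names and toy logic are made up):

```python
import tensorflow as tf

# tf.function: pure TF ops, traced into a serializable graph
@tf.function
def scaled_sum(x):
    return tf.reduce_sum(x) * 2.0

# tf.py_function: arbitrary Python runs eagerly inside the graph, but the
# resulting graph cannot be serialized and ported
def py_len(s):
    return len(s.numpy())  # plain Python, needs an eager tensor

@tf.function
def graph_fn(s):
    return tf.py_function(py_len, [s], tf.int64)

print(scaled_sum(tf.constant([1.0, 2.0])))  # tf.Tensor(6.0, ...)
print(graph_fn(tf.constant("hello")))       # tf.Tensor(5, ...)
```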
The output type of tf.py_function cannot be a nested sequence. When using tf.py_function with the tf.data API, however, you need to create a wrapper function, and you can nest the outputs in that wrapper.

def func(x):
    # some code to convert x to non-tensors
    # some dictionary operations, etc.
    # if this function and tf_func are instance methods, instance-variable updates made here
    # have no effect in tf_func, as the execution order is complex
    # make sure the return type is consistent with output_types passed to tf.py_function
    pass

# merge, then split again
def tf_func(x):
    # output_types only supports 1-d tuples or lists, without nesting,
    # so func must return its results as a flat (1-d) sequence;
    # the final result is cast to a list of tensors
    result = tf.py_function(func, [x], output_types)
    for every_item in result:  # every_item is a tensor, not a plain value
        # To pass the shape check when building the tf graph:
        # TensorFlow can't infer the shape (that would require analyzing the Python body of func),
        # so the shape has to be set manually
        every_item.set_shape(tf.TensorShape([s1, s2]))
    x = result[:11]
    y = result[-1]
    # must return the sequence as a tuple instead of a list, as map only supports tuples
    return tuple(x), y

dataset = tf.data.Dataset.range(3).map(tf_func)
Note that TFRecord/tf.data.Dataset use sequential access, while keras.Sequence and torch use random access. If you use tf.data.Dataset, it is hard to resume at the exact iteration within an epoch; only partial resumption, by checkpointing the iterator state, is supported:
"""Resume an iterator is supported"""
import tensorflow as tf
ds = tf.data.Dataset.range(100).shuffle(100, seed=42, reshuffle_each_iteration=False)
it = iter(ds) # iterator state + some dataset state
ckpt = tf.train.Checkpoint(foo=it)
mgr = tf.train.CheckpointManager(ckpt, '/tmp/x', max_to_keep=3)
for _ in range(3): print(next(it))
# tf.Tensor(72, shape=(), dtype=int64)
# tf.Tensor(83, shape=(), dtype=int64)
# tf.Tensor(19, shape=(), dtype=int64)
mgr.save()
print("Saved to {}".format(mgr.latest_checkpoint))
# Saved to /tmp/x/ckpt-1
for _ in range(3): print(next(it))
# tf.Tensor(74, shape=(), dtype=int64)
# tf.Tensor(33, shape=(), dtype=int64)
# tf.Tensor(93, shape=(), dtype=int64)
ckpt.restore(mgr.latest_checkpoint)
print("Restored from {}".format(mgr.latest_checkpoint))
# Restored from /tmp/x/ckpt-1
for _ in range(3): print(next(it))
# tf.Tensor(74, shape=(), dtype=int64)
# tf.Tensor(33, shape=(), dtype=int64)
# tf.Tensor(93, shape=(), dtype=int64)
"""Dataset is not fully resumed"""
ds = tf.data.Dataset.range(100).shuffle(100, seed=42, reshuffle_each_iteration=True)
ckpt = tf.train.Checkpoint(foo=it)
mgr = tf.train.CheckpointManager(ckpt, '/tmp/x', max_to_keep=3)
enumerate(ds) # A1
mgr.save()
enumerate(ds) # A2
ckpt.restore(mgr.latest_checkpoint)
enumerate(ds) # A1 again instead of A2, dataset starts over