# Data loading

As of TF 2.3, the officially recommended way is to use TFRecord, which stores preprocessed tensors on disk so that the data pipeline can read them back later.

  • saves the time otherwise spent regenerating tensors on every run
  • aggregates many images into a single file
  • serializes records with protocol buffers, a mechanism similar to JSON
  • is read as a stream, so there is no random access
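
As a rough illustration of the points above, the following sketch writes a few records into one TFRecord file and streams them back. The helper `make_example`, the feature names `pixels` and `label`, and the toy data are hypothetical choices made for this example, not part of any real dataset.

```python
import os
import tempfile

import tensorflow as tf


def make_example(pixels, label):
    """Pack one (pixels, label) pair into a tf.train.Example protocol buffer."""
    feature = {
        "pixels": tf.train.Feature(float_list=tf.train.FloatList(value=pixels)),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Hypothetical toy data: three "images" of four pixels each.
samples = [
    ([0.0, 0.1, 0.2, 0.3], 0),
    ([1.0, 1.1, 1.2, 1.3], 1),
    ([2.0, 2.1, 2.2, 2.3], 2),
]

# All samples are aggregated into a single file.
path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for pixels, label in samples:
        writer.write(make_example(pixels, label).SerializeToString())

# Reading is a sequential stream: records come back in the order they were written.
feature_spec = {
    "pixels": tf.io.FixedLenFeature([4], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64),
}
dataset = tf.data.TFRecordDataset(path).map(
    lambda record: tf.io.parse_single_example(record, feature_spec)
)
labels = [int(item["label"]) for item in dataset]
```

The `feature_spec` dictionary has to match what was written, since the file itself carries no schema that `parse_single_example` could use.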

TFRecord Inspection

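import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset("path-to-file")

for raw_record in raw_dataset.take(1):
    # Each record is a serialized tf.train.Example; parse it to inspect its features
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
    # Optionally convert the parsed message to a Python dict via JSON:
    # from google.protobuf.json_format import MessageToJson
    # json_string = MessageToJson(example)
    # json.loads(json_string)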

There are other ways:

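  • tf.keras.utils.Sequence: follows the same loading philosophy as PyTorch's Dataset.

    class Dataset(tf.keras.utils.Sequence):
      def __init__(self):
        ...  # store file lists, batch size, etc.
      def __getitem__(self, idx):
        ...  # return the batch at index idx
      def __len__(self):
        ...  # return the number of batches per epoch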
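  • tf.data.Dataset.from_generator: either write a generator function or convert a Sequence into a generator.
    output_types must be given; output_shapes is optional but worth specifying so that downstream ops see known shapes. (From TF 2.4 on, output_signature replaces both.)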

    # Converts a Sequence to a generator
    def seq_to_dataset(self, seq):
      def generator():
          multi_enqueuer = tf.keras.utils.OrderedEnqueuer(seq)
          multi_enqueuer.start()  # the enqueuer must be started before get() yields batches
          batches = multi_enqueuer.get()  # create the batch generator once, not on every step
          while True:
              batch_xs, batch_ys = next(batches)
              batch_ys = tf.constant(batch_ys, dtype=tf.float32)
              # the Sequence must return batch_xs as a (possibly nested) tuple, not a list
              yield batch_xs, batch_ys
      x_types = ((tf.int32,) * 10 + (tf.float32,)) * 2
      x_shapes = (tf.TensorShape([None, None]),) * 22
      # from_generator checks each yielded element against the structure of output_types.
      # That structure only supports tuples, so the yielded element must use tuples too:
      # the Sequence must return tuples instead of lists; the structure can be nested.
      dataset = tf.data.Dataset.from_generator(generator,
                                               output_types=(x_types, tf.float32),
                                               output_shapes=(x_shapes, tf.TensorShape([None])))
      return dataset
  • tf.data.Dataset: the from_* constructors play the role of PyTorch's __init__ and map plays the role of __getitem__, but every operation works on tensors. For non-tensor operations, use tf.py_function.

    tf.data.Dataset.list_files vs tf.data.Dataset.from_tensor_slices

    # as of tf2.1
    # The file_pattern argument should be a small number of glob patterns. If your filenames have already been globbed, use Dataset.from_tensor_slices(filenames) instead, as re-globbing every filename with list_files may result in poor performance with remote storage systems.
    tf.data.Dataset.list_files(file_pattern, shuffle=None, seed=None)

    tf.py_function vs tf.function

    tf.py_function lets non-TF (plain Python) operations run in graph mode; however, the wrapped function cannot be serialized, so the result is not portable. tf.function, by contrast, is a utility that converts operations to graph mode, and it requires TF operations.

    tf.py_function cannot return a nested sequence. With the tf.data API you need a wrapper function around tf.py_function anyway, and that wrapper can regroup the flat outputs into a nested structure.

    def func(x):
      # some code to convert x to non-tensors
      # some dictionary operation, bla bla
      # if this function and tf_func are instance methods, instance variable updates made here
      # have no effect in tf_func, because the execution order is complex
      # make sure the return types are consistent with output_types in tf.py_function
      ...
    # merge then split again
    def tf_func(x):
      # output_types (a placeholder defined elsewhere) only supports flat tuples or lists, without nesting
      # func must return its results as a flat sequence
      # the final result is cast to a list
      result = tf.py_function(func, [x], output_types)
      for every_item in result:
        # assumes each item is a tensor, not a plain list
        # to pass the shape checks when the graph is built:
        # TensorFlow can't figure out the shape (it would require analyzing the Python code of the function body),
        # so give the shape manually (s1 and s2 are placeholders)
        every_item.set_shape(tf.TensorShape([s1, s2]))
      x = result[:11]
      y = result[-1]
      # must return the sequence as a tuple instead of a list, as map only supports tuples
      return tuple(x), y
    dataset = tf.data.Dataset.range(3).map(tf_func)
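
To make the alternatives above concrete, here is a minimal sketch that runs end to end: a toy `tf.keras.utils.Sequence`, the same data exposed through `tf.data.Dataset.from_generator`, and a `map` step that calls plain Python through `tf.py_function`. The names `ToySequence`, `python_side`, `tf_side`, and the toy data are hypothetical and exist only for illustration. It uses the older `output_types`/`output_shapes` arguments that these notes refer to.

```python
import numpy as np
import tensorflow as tf


class ToySequence(tf.keras.utils.Sequence):
    """Random-access batches, in the spirit of a PyTorch Dataset."""

    def __init__(self, xs, ys, batch_size):
        super().__init__()
        self.xs, self.ys, self.batch_size = xs, ys, batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.xs) / self.batch_size))

    def __getitem__(self, idx):
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.xs[rows], self.ys[rows]


seq = ToySequence(np.arange(6, dtype=np.int32).reshape(6, 1),
                  np.arange(6, dtype=np.float32), batch_size=2)


def generator():
    # Yield tuples (not lists) so the structure matches output_types.
    for idx in range(len(seq)):
        batch_xs, batch_ys = seq[idx]
        yield batch_xs, batch_ys


dataset = tf.data.Dataset.from_generator(
    generator,
    output_types=(tf.int32, tf.float32),
    output_shapes=(tf.TensorShape([None, 1]), tf.TensorShape([None])),
)


def python_side(xs, ys):
    # Runs eagerly, so arbitrary Python/NumPy code is allowed here.
    table = {0: 10, 1: 20}  # a plain dict lookup, not a TF op
    offsets = [table[int(v) % 2] for v in xs.numpy()[:, 0]]
    return tf.constant(offsets, dtype=tf.int32), ys


def tf_side(xs, ys):
    # Flat (non-nested) output types; shapes are lost and must be set by hand.
    offsets, ys_out = tf.py_function(python_side, [xs, ys], [tf.int32, tf.float32])
    offsets.set_shape([None])
    ys_out.set_shape([None])
    return offsets, ys_out


batches = [(o.numpy().tolist(), y.numpy().tolist()) for o, y in dataset.map(tf_side)]
```

Indexing `seq[idx]` works in any order, whereas the resulting `dataset` can only be consumed front to back.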

Note that TFRecord and tf.data.Dataset access data sequentially, while keras.Sequence and torch use random access. With tf.data.Dataset it is therefore hard to resume at the exact iteration within an epoch. Only limited forms of resumption are supported:

"""Resume an iterator is supported"""
import tensorflow as tf
ds = tf.data.Dataset.range(100).shuffle(100, seed=42, reshuffle_each_iteration=False)
it = iter(ds) # iterator state + some dataset state
ckpt = tf.train.Checkpoint(foo=it)
mgr = tf.train.CheckpointManager(ckpt, '/tmp/x', max_to_keep=3)

for _ in range(3): print(next(it))
# tf.Tensor(72, shape=(), dtype=int64)
# tf.Tensor(83, shape=(), dtype=int64)
# tf.Tensor(19, shape=(), dtype=int64)


mgr.save()  # checkpoint the iterator state after the first three elements
print("Saved to {}".format(mgr.latest_checkpoint))
for _ in range(3): print(next(it))
# Saved to /tmp/x/ckpt-1
# tf.Tensor(74, shape=(), dtype=int64)
# tf.Tensor(33, shape=(), dtype=int64)
# tf.Tensor(93, shape=(), dtype=int64)

ckpt.restore(mgr.latest_checkpoint)  # rewind the iterator to the saved position
print("Restored from {}".format(mgr.latest_checkpoint))
for _ in range(3): print(next(it))
# Restored from /tmp/x/ckpt-1
# tf.Tensor(74, shape=(), dtype=int64)
# tf.Tensor(33, shape=(), dtype=int64)
# tf.Tensor(93, shape=(), dtype=int64)

"""Dataset is not fully resumed"""
ds = tf.data.Dataset.range(100).shuffle(100, seed=42, reshuffle_each_iteration=True)
it = iter(ds)  # new iterator over the reshuffling dataset
ckpt = tf.train.Checkpoint(foo=it)
mgr = tf.train.CheckpointManager(ckpt, '/tmp/x', max_to_keep=3)
enumerate(ds) # A1


enumerate(ds) # A2


enumerate(ds) # A1 again instead of A2, dataset starts over
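
The iterator checkpointing pattern above can also be checked programmatically. The sketch below is a variation on it: it writes to a temporary directory (an assumption made here, instead of a fixed path), uses plain `Checkpoint.save` and `restore` rather than a `CheckpointManager`, and compares the elements read after saving with those read after restoring instead of relying on specific printed values.

```python
import os
import tempfile

import tensorflow as tf

# Fixed shuffle order (reshuffle_each_iteration=False), as in the first example above.
ds = tf.data.Dataset.range(100).shuffle(100, seed=42, reshuffle_each_iteration=False)
it = iter(ds)
ckpt = tf.train.Checkpoint(iterator=it)

first = [int(next(it)) for _ in range(3)]

# Save the iterator position after three elements.
prefix = os.path.join(tempfile.mkdtemp(), "iter_ckpt")
save_path = ckpt.save(prefix)

after_save = [int(next(it)) for _ in range(3)]

# Rewind to the saved position and read the same three elements again.
ckpt.restore(save_path)
after_restore = [int(next(it)) for _ in range(3)]
```

What is saved here is the state of one iterator; the second case above is about what happens across epochs with reshuffling, which this sketch does not cover.
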
Last Updated: 8/4/2022, 9:25:06 PM