# Checkpoint

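TensorFlow 2 checkpoints are object-based: values are saved together with the graph of tracked objects (models, layers, optimizers, variables) that reference them.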

# tf.train.Checkpoint

# Basics

save() and restore() are used as a pair, and so are write() and read(). The first pair maintains a save_counter that is used to number the checkpoint files. With the second pair, you have to name the checkpoint files yourself.

restore() and read() return a status object, which provides the following four checks:

  • assert_consumed: every value in the checkpoint has been restored into the model, and every model variable has a matching value in the checkpoint.
  • assert_existing_objects_matched: every model variable that has already been built must exist among the transitive dependencies (children or grandchildren nodes) in the checkpoint graph.
    • If the model has not been built yet, this check also passes.
    • Any checkpoint value that is not restored raises a warning.
  • assert_nontrivial_match: at least some model variables must exist among the transitive dependencies (children or grandchildren nodes) in the checkpoint graph.
    • With the save/restore pair, this check always passes because save_counter is matched.
    • Any checkpoint value that is not restored raises a warning.
  • expect_partial: performs no check and silences the warnings about unrestored values.

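Before TensorFlow 2.4, a checkpoint saved with model.save_weights cannot be loaded with tf.train.Checkpoint(key=model). The reason is that save_weights makes the model itself the root of the saved object graph, while Checkpoint(key=model) nests the model under the edge key, so the variable paths do not match.

Since TensorFlow 2.4, tf.train.Checkpoint accepts a root argument: tf.train.Checkpoint(root=model), or positionally tf.train.Checkpoint(model). This makes the model the root object, so the resulting checkpoints are compatible with model.save_weights / model.load_weights.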

# Variable Names

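The Python attribute names, together with the keyword arguments passed to tf.train.Checkpoint, serve as the edges of the checkpoint graph. It is therefore important to understand how the stored labels are generated.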

When a checkpoint is saved, each variable is stored under a label built from those edges rather than under its own variable.name. For example:

class Regress(tf.keras.Model):
  """Toy model with three Dense layers, used to illustrate variable naming."""

  def __init__(self):
    super(Regress, self).__init__()
    # No explicit name: Keras derives "dense" from the class name.
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    # Explicit name: used as-is in the variable name ("heihei").
    self.dense2 = tf.keras.layers.Dense(4, activation=tf.nn.relu, name='heihei')
    # Derived name collides with dense1's, so a suffix is added ("dense_1").
    self.dense3 = tf.keras.layers.Dense(4, activation=tf.nn.relu)

  def call(self, inputs):
    # Every layer has to be called once so that its variables get created.
    x = self.dense1(inputs)
    x = self.dense2(x)
    return self.dense3(x)

import pathlib

r = Regress()
# Layers create their variables on the first call, so call the model once.
r(tf.ones([1, 1]))

# If no name is provided, the snake_case class name is used ("dense"); if that
# name is already taken, a "_<n>" suffix is added ("dense_1").
# variable name of dense1: regress/dense/kernel:0
# variable name of dense2: regress/heihei/kernel:0
# variable name of dense3: regress/dense_1/kernel:0
r.variables

checkpoint_directory = './training_checkpoints'
checkpoint_prefix = pathlib.Path(checkpoint_directory) / 'ckpt'

# The checkpoint uses a different labeling scheme: the label is built from the
# keyword argument passed to tf.train.Checkpoint ("zy") and the Python
# attribute names (dense1, dense2, dense3), not from variable.name.
# variable name of dense1: zy/dense1/kernel/.ATTRIBUTES/VARIABLE_VALUE
# variable name of dense2: zy/dense2/kernel/.ATTRIBUTES/VARIABLE_VALUE
# variable name of dense3: zy/dense3/kernel/.ATTRIBUTES/VARIABLE_VALUE
checkpoint = tf.train.Checkpoint(zy=r)
checkpoint.save(file_prefix=checkpoint_prefix)

# Practices

Example - save checkpoint in Regress model and restore it in RegressToRestore:

"""
Compile Fit Save

Build (Compile) Load (Compile) Fit
Delayed Restore
Build first to create internal variables
"""

import tensorflow as tf
import pathlib

class Regress(tf.keras.Model):
  """Model that gets saved: Dense(4, relu) followed by Dense(5, softmax)."""

  def __init__(self):
    super().__init__()
    # The attribute names become edges in the checkpoint graph.
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(
        5, activation=tf.nn.softmax, input_shape=(4,))

    # ...

  def call(self, inputs):
    hidden = self.dense1(inputs)
    return self.dense2(hidden)

  def dummy_input(self):
    """Return a 1x1 example batch that can be used to build the model."""
    return tf.constant([[1.]])

# Save side: compile, fit, then save.
model = Regress()

opt = tf.keras.optimizers.Adam(0.1)
model.compile(loss="mean_squared_error", optimizer=opt)


# One training step on a single example. It builds the layers (input dim 1 ->
# 4 -> 5) and lets the optimizer create its variables before saving.
example_x = tf.constant([[1.]])
example_y = tf.constant([[1.,2.,3.,4.,5.]])
model.fit(example_x, example_y, epochs=1)


checkpoint_directory = './training_checkpoints'
checkpoint_prefix = pathlib.Path(checkpoint_directory) / 'ckpt'

# Only model.dense1 (under the edge "zy") and the optimizer (under "opt") are
# passed to the checkpoint; model.dense2 is not.
# variable name => label in ckpt
# regress/dense/kernel:0 => zy/kernel/.ATTRIBUTES/VARIABLE_VALUE
# Other entries listed by tf.train.list_variables include, e.g.:
# [('_CHECKPOINTABLE_OBJECT_GRAPH', []), ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', [])]
checkpoint = tf.train.Checkpoint(zy=model.dense1, opt=opt)
# save() numbers the checkpoint file using save_counter.
checkpoint.save(file_prefix=checkpoint_prefix)


class RegressToRestore(tf.keras.Model):
  """Restore-side model, identical in shape to Regress but with other attribute names."""

  def __init__(self):
    super().__init__()
    # Different attribute names than Regress; matching relies on the
    # Checkpoint keyword instead (see the restore script).
    self.dense3 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense4 = tf.keras.layers.Dense(
        5, activation=tf.nn.softmax, input_shape=(4,))

    # ...

  def call(self, inputs):
    hidden = self.dense3(inputs)
    return self.dense4(hidden)

  def dummy_input(self):
    """Return a 1x1 example batch that can be used to build the model."""
    return tf.constant([[1.]])


# Restore side: build, restore, compile, fit.
model = RegressToRestore()

# Build first so that dense3/dense4 create their variables. Calling the model
# on a dummy input does it; model.build with an explicit shape is an alternative.
# model.build(input_shape=(None,1))
model(model.dummy_input())


opt = tf.keras.optimizers.Adam(0.1)
# same argument name => zy/kernel/.ATTRIBUTES/VARIABLE_VALUE => graph matches
# (the attribute name dense3 vs dense1 does not matter here; the label comes
# from the Checkpoint keyword "zy")
checkpoint = tf.train.Checkpoint(zy=model.dense3, opt=opt)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))


model.compile(loss="mean_squared_error", optimizer=opt)
example_x = tf.constant([[1.]])
example_y = tf.constant([[1.,2.,3.,4.,5.]])
# fit() makes the optimizer create its variables; the saved optimizer values
# are restored at that point (delayed restore).
model.fit(example_x, example_y, epochs=1)


# Checked after fit(): before it, the optimizer values in the checkpoint have
# no variables to be restored into yet.
status.assert_consumed()

Remark:

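  • To restore values immediately, build the model first (e.g. call it once on a dummy input). If the model is not built yet, restoration is deferred until its variables are created, and checks such as assert_consumed() fail until then.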

  • To silence warnings such as "Unresolved object in checkpoint: (root).optimizer.beta_1", either train the model (fit) so that the optimizer creates its variables, or call expect_partial() on the restore status.

  • To read variables inside a checkpoint:

    tf.train.list_variables(ckpt_path)
    

There are two ways to restore only a sub-graph with tf.train.Checkpoint:

  • save the full model in the checkpoint, and when restoring, embed the part to be restored in a fake structure that mirrors the original paths
  • save both the full model and the sub-graph in the checkpoint

With either method, unless expect_partial() is called, there may be warnings about unresolved values in the checkpoint.

Example - save checkpoint in Net and restore a sub graph only:

# Model
class Net(tf.keras.Model):
    """A simple linear model: a single Dense layer with 5 units."""

    def __init__(self):
        super(Net, self).__init__()
        # Saved under the label "l1" in the checkpoint graph.
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)

net = Net()
# Call the model once so that l1 creates its kernel and bias. An unbuilt
# Dense layer owns no variables, so nothing would be saved for it.
net(tf.ones([1, 1]))
# Dense biases start at zero; give the bias a visible value so the restore
# below is observable.
net.l1.bias.assign(tf.ones([5]))

# Way 1: embed the sub graph to be restored in a fake structure
opt = tf.keras.optimizers.Adam(0.1)
iterator = iter(tf.data.Dataset.from_tensor_slices(tf.range(10.)).repeat().batch(2))
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
manager.save()

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros

# Fake structure: mirrors the saved path net/l1/bias, so only the bias is restored
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # All ones: the restored value of net.l1.bias

Another method:

# Way 2: save both full model and sub graph in checkpoint
# Note that the variable labels are slightly altered: net.l1.bias is reachable
# through two paths and is stored once, under the "bias" label shown below.
# This does not affect full model restoration.
# net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE, full model variable outside sub graph
# bias/.ATTRIBUTES/VARIABLE_VALUE, variable in sub graph
net(tf.ones([1, 1]))  # make sure net.l1 is built, so that net.l1.bias exists
ckpt = tf.train.Checkpoint(net=net, bias=net.l1.bias)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
manager.save()

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
restore_ckpt = tf.train.Checkpoint(bias=to_restore)
status = restore_ckpt.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # the restored value of net.l1.bias

net2 = Net()
restore_ckpt = tf.train.Checkpoint(net=net2)  # full model can also be restored
# net2 is not built yet, so its values are restored when it is first called.
status = restore_ckpt.restore(tf.train.latest_checkpoint('./tf_ckpts/'))

Example - Adding a new layer in existing model

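TensorFlow checkpoint restoration matches values by their path in the object graph, so it tolerates added layers. If the model gains an additional layer, the previously saved weights are still loaded into the existing layers, as long as those layers keep their attribute names. The new layer keeps its initial values.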

# restore weight
new_model = build_new_model()  # additional sub_layer3 is added
latest_check_point = tf.train.latest_checkpoint('old_ckpt')  # model with two sub_layer
# The keyword must match the one used when saving ("model") so the paths line
# up; sub_layer3 has no saved values and keeps its initial weights.
ckpt = tf.train.Checkpoint(model=new_model)
status = ckpt.restore(latest_check_point)

# tf.train.CheckpointManager

Manages multiple checkpoints by keeping some and deleting unneeded ones.

import tensorflow as tf
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
# Keep at most the 5 most recent checkpoints in /tmp/model; older ones are deleted.
manager = tf.train.CheckpointManager(checkpoint, directory="/tmp/model", max_to_keep=5)
# Resume from the newest checkpoint tracked by the manager.
status = checkpoint.restore(manager.latest_checkpoint)

# tf.train.load_checkpoint

Returns a CheckpointReader. Its get_variable_to_shape_map() and get_variable_to_dtype_map() methods map each stored variable name to its shape and dtype.

# tf.keras.Model

# Signatures of the tf.keras.Model weight methods (reference listing, not runnable as-is).
save_weights(
    filepath, overwrite=True, save_format=None, options=None
)

load_weights(
    filepath, by_name=False, skip_mismatch=False, options=None
)
Last Updated: 8/4/2022, 9:25:06 PM