# Checkpoint
TensorFlow 2 checkpoints are object-based: variables are saved and matched through an object graph of trackable dependencies rather than by flat variable names.
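A tiny sketch of what object-based means, using a throwaway layer (the path is hypothetical): the checkpoint keys are paths through the object graph rather than `dense/kernel:0`-style variable names.

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(3)
layer.build((None, 2))
root = tf.train.Checkpoint(layer=layer)
path = root.save('./ckpts/object-based')

# Keys are paths through the object graph, not graph variable names
print(tf.train.list_variables(path))
# e.g. [('_CHECKPOINTABLE_OBJECT_GRAPH', []),
#       ('layer/bias/.ATTRIBUTES/VARIABLE_VALUE', [3]),
#       ('layer/kernel/.ATTRIBUTES/VARIABLE_VALUE', [2, 3]),
#       ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', [])]
```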
# tf.train.Checkpoint
# Basics
`save` and `restore` are used as a pair, and `write` and `read` are used as a pair. The first pair maintains a `save_counter` that is used to number the checkpoint files; with the second pair, checkpoint naming must be implemented by the caller.
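A minimal sketch contrasting the two pairs, with hypothetical paths and a throwaway variable:

```python
import tensorflow as tf

v = tf.Variable(1.0)
ckpt = tf.train.Checkpoint(v=v)

# save/restore: save_counter is incremented and appended automatically
# -> ./ckpts/demo-1, ./ckpts/demo-2, ...
path = ckpt.save('./ckpts/demo')
ckpt.restore(path)

# write/read: the prefix is used verbatim; any numbering scheme is the caller's job
ckpt.write('./ckpts/manual-0001')
ckpt.read('./ckpts/manual-0001')
```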
`restore` or `read` returns a status object, which provides the following four check mechanisms:
- `assert_consumed`: every variable in the checkpoint is restored into the model, and the model has no additional variable without a mapping.
- `assert_existing_objects_matched`: every built model variable must exist in the transitive dependencies (children or grandchildren nodes) of the checkpoint graph. If the model is not yet built, this check passes as well. Any unrestored variable in the checkpoint still raises warnings.
- `assert_nontrivial_match`: some variables in the model must exist in the transitive dependencies (children or grandchildren nodes) of the checkpoint graph. For the `save`/`restore` pair, this check always passes thanks to `save_counter`. Any unrestored variable in the checkpoint still raises warnings.
- `expect_partial`: no check and no warnings.
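A small sketch of how the four checks behave when the restoring object graph covers only part of the checkpoint (the variables `a`, `b` and the path are hypothetical):

```python
import tensorflow as tf

src = tf.train.Checkpoint(a=tf.Variable(1.0), b=tf.Variable(2.0))
path = src.save('./ckpts/status-demo')

# Restore into an object graph that only declares `a`
dst = tf.train.Checkpoint(a=tf.Variable(0.0))
status = dst.restore(path)

status.assert_existing_objects_matched()  # passes: every object in dst was matched
status.assert_nontrivial_match()          # passes: at least one real match besides save_counter
status.expect_partial()                   # silences warnings about the unrestored `b`
# status.assert_consumed()                # would raise: `b` in the checkpoint was never restored
```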
Before TensorFlow 2.4, a checkpoint saved with `model.save_weights` could not be loaded with `tf.train.Checkpoint(key=model)`, because `save_weights` attaches the model at the root of the checkpoint object graph while `tf.train.Checkpoint(key=model)` attaches it under `key`, so the object graphs do not line up. After TensorFlow 2.4, the `root` keyword argument replaced the positional argument, and attaching the model as `root` produces a checkpoint layout compatible with `save_weights`.
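A hedged sketch of this compatibility, assuming a TensorFlow version where `tf.train.Checkpoint` accepts the `root` keyword argument (the paths and models are hypothetical):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(2,))])
model.save_weights('./sw_ckpt/weights')  # TF-format checkpoint with the model at the root

restored = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(2,))])
# Attaching the model via `root` lines the object graph up with what save_weights wrote
ckpt = tf.train.Checkpoint(root=restored)
ckpt.restore('./sw_ckpt/weights').expect_partial()
```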
# Variable Names
Variable names serve as the edges of the graph, so it is important to understand how they are generated. When a checkpoint is saved, variables are re-labelled before being written. For example:
```python
import tensorflow as tf
import pathlib

class Regress(tf.keras.Model):
    def __init__(self):
        super(Regress, self).__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(4, activation=tf.nn.relu, name='heihei')
        self.dense3 = tf.keras.layers.Dense(4, activation=tf.nn.relu)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.dense3(x)

r = Regress()
r(tf.constant([[1.]]))  # build the model so the variables are created
# If no name is provided, the class name is used; if class names collide, a _x suffix is added
# variable name of dense1: regress/dense/kernel:0
# variable name of dense2: regress/heihei/kernel:0
# variable name of dense3: regress/dense_1/kernel:0
r.variables

checkpoint_directory = './training_checkpoints'
checkpoint_prefix = pathlib.Path(checkpoint_directory) / 'ckpt'
# Checkpoints use a different labelling mechanism:
# the Python attribute names and the tf.train.Checkpoint keyword arguments form the path
# variable name of dense1: zy/dense1/kernel/.ATTRIBUTES/VARIABLE_VALUE
# variable name of dense2: zy/dense2/kernel/.ATTRIBUTES/VARIABLE_VALUE
# variable name of dense3: zy/dense3/kernel/.ATTRIBUTES/VARIABLE_VALUE
checkpoint = tf.train.Checkpoint(zy=r)
checkpoint.save(file_prefix=checkpoint_prefix)
```
# Practices
Example - save a checkpoint from a `Regress` model and restore it into a `RegressToRestore` model:
"""
Compile Fit Save
Build (Compile) Load (Compile) Fit
Delayed Restore
Build first to create internal variables
"""
import tensorflow as tf
import pathlib
class Regress(tf.keras.Model):
def __init__(self):
super(Regress, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax, input_shape=(4,))
# ...
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
def dummy_input(self):
return tf.constant([[1.]])
model = Regress()
opt = tf.keras.optimizers.Adam(0.1)
model.compile(loss="mean_squared_error", optimizer=opt)
example_x = tf.constant([[1.]])
example_y = tf.constant([[1.,2.,3.,4.,5.]])
model.fit(example_x, example_y, epochs=1)
checkpoint_directory = './training_checkpoints'
checkpoint_prefix = pathlib.Path(checkpoint_directory) / 'ckpt'
# label in ckpt => variable
# regress/dense/kernel:0 => zy/kernel/.ATTRIBUTES/VARIABLE_VALUE
# [('_CHECKPOINTABLE_OBJECT_GRAPH', []), ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', [])]
checkpoint = tf.train.Checkpoint(zy=model.dense1, opt=opt)
checkpoint.save(file_prefix=checkpoint_prefix)
class RegressToRestore(tf.keras.Model):
def __init__(self):
super(RegressToRestore, self).__init__()
self.dense3 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense4 = tf.keras.layers.Dense(5, activation=tf.nn.softmax, input_shape=(4,))
# ...
def call(self, inputs):
x = self.dense3(inputs)
return self.dense4(x)
def dummy_input(self):
return tf.constant([[1.]])
model = RegressToRestore()
# model.build(input_shape=(None,1))
model(model.dummy_input())
opt = tf.keras.optimizers.Adam(0.1)
# same argument name => zy/kernel/.ATTRIBUTES/VARIABLE_VALUE => graph matches
checkpoint = tf.train.Checkpoint(zy=model.dense3, opt=opt)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
model.compile(loss="mean_squared_error", optimizer=opt)
example_x = tf.constant([[1.]])
example_y = tf.constant([[1.,2.,3.,4.,5.]])
model.fit(example_x, example_y, epochs=1)
status.assert_consumed()
Remark:

- To load a model from a checkpoint, the model must be built first.
- To silence warnings such as `Unresolved object in checkpoint: (root).optimizer.beta_1`, either fit the model (so the optimizer slot variables are created) or call `expect_partial()`.
- To read the variables inside a checkpoint: `tf.train.list_variables(ckpt_path)`.
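For example, a quick sketch of listing the contents of the checkpoint saved above (the printed keys are illustrative):

```python
import tensorflow as tf

ckpt_path = tf.train.latest_checkpoint('./training_checkpoints')
for name, shape in tf.train.list_variables(ckpt_path):
    print(name, shape)
# e.g. _CHECKPOINTABLE_OBJECT_GRAPH []
#      save_counter/.ATTRIBUTES/VARIABLE_VALUE []
#      zy/kernel/.ATTRIBUTES/VARIABLE_VALUE [1, 4]
```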
There are two ways to load a subgraph with `tf.train.Checkpoint`:

- save the full model in the checkpoint and wrap the subgraph to be restored in a fake structure
- save both the full model and the subgraph in the checkpoint

Whichever method is used, warnings about unresolved variables in the checkpoint may appear unless `expect_partial` is called.
Example - save a checkpoint from `Net` and restore only a subgraph:
```python
import tensorflow as tf

# Model
class Net(tf.keras.Model):
    """A simple linear model."""
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)

net = Net()
net(tf.constant([[1.]]))  # build the model so l1's kernel and bias exist
opt = tf.keras.optimizers.Adam(0.1)
iterator = iter(tf.data.Dataset.range(10))

# Way 1: embed the sub graph to be restored in a fake structure
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
manager.save()

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
# Fake structure mirroring net -> l1 -> bias
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # The bias restored from net.l1
```
Another method:
```python
# Way 2: save both full model and sub graph in checkpoint
# Note that the variable labels are slightly altered; this does not affect full-model restoration
# net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE -> full-model variable outside the sub graph
# bias/.ATTRIBUTES/VARIABLE_VALUE          -> variable in the sub graph
ckpt = tf.train.Checkpoint(net=net, bias=net.l1.bias)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
manager.save()

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
restore_ckpt = tf.train.Checkpoint(bias=to_restore)
status = restore_ckpt.restore(tf.train.latest_checkpoint('./tf_ckpts/'))

net2 = Net()
restore_ckpt = tf.train.Checkpoint(net=net2)  # the full model can also be restored this way
restore_ckpt.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
```
Example - adding a new layer to an existing model. TensorFlow checkpoint restoration handles this case: if the model gains an additional layer, TensorFlow still matches the existing partial structure and loads the previously saved weights for the layers that were already present.
```python
# restore weights into the extended model
new_model = build_new_model()  # builder defined elsewhere; an additional sub_layer3 is added
latest_check_point = tf.train.latest_checkpoint('old_ckpt')  # checkpoint of the model with two sub-layers
ckpt = tf.train.Checkpoint(**{'model': new_model})
ckpt.restore(latest_check_point)
```
# tf.train.CheckpointManager
Manages multiple checkpoints by keeping some and deleting unneeded ones.
```python
import tensorflow as tf

# optimizer and model are assumed to be defined as in the earlier examples
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=5)
status = checkpoint.restore(manager.latest_checkpoint)
manager.save()  # writes a new checkpoint; old ones beyond max_to_keep are deleted
```
# tf.train.load_checkpoint
Returns a `CheckpointReader`, which exposes a variable-to-shape map and a variable-to-dtype map.
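A minimal sketch of using the reader, reusing the checkpoint directory from the examples above:

```python
import tensorflow as tf

reader = tf.train.load_checkpoint('./training_checkpoints')  # accepts a directory or a checkpoint prefix
shape_map = reader.get_variable_to_shape_map()
dtype_map = reader.get_variable_to_dtype_map()
for key in shape_map:
    print(key, shape_map[key], dtype_map[key].name)
# Individual values can be read by their checkpoint key, e.g.
# reader.get_tensor('zy/kernel/.ATTRIBUTES/VARIABLE_VALUE')
```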
# tf.keras.Model
```python
save_weights(
    filepath, overwrite=True, save_format=None, options=None
)
load_weights(
    filepath, by_name=False, skip_mismatch=False, options=None
)
```
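A short sketch of the round trip with the default TF-format checkpoint (the path and model are hypothetical):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(1,))])
model.save_weights('./keras_ckpt/weights')  # no .h5 suffix, so a TF-format checkpoint is written

clone = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(1,))])
clone.load_weights('./keras_ckpt/weights')  # weights restored; delayed until the model is built if needed
```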