# Keras
The functional API writes model components line by line and then uses `Model(inputs, outputs)` to get an encapsulated model. After building a model with the functional API, the graph and weights are initialized, so a checkpoint can be loaded and checked immediately.
Model subclassing encapsulates the model inherently, but weights and graphs are lazily initialized, so `model.summary()` or `checkpoint.restore(ckpt).assert_consumed()` only work after the model has been called.
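A minimal sketch contrasting the two styles (layer sizes and shapes are illustrative):

```python
import tensorflow as tf

# Functional API: the graph and weights exist as soon as the model is constructed.
inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(2)(inputs)
functional_model = tf.keras.Model(inputs, outputs)
functional_model.summary()  # works immediately

# Subclassing: weights and graph are created lazily, on the first call.
class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.dense(inputs)

subclassed_model = SubclassedModel()
# subclassed_model.summary()        # would raise: the model has not been built yet
subclassed_model(tf.zeros([1, 8]))  # the first call builds the weights
subclassed_model.summary()          # now works
```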
`tf.keras.layers.Layer` inherits from `tf.Module`, and `tf.keras.Model` inherits from `tf.keras.layers.Layer`. The `._layers` attribute is introduced in `tf.keras.layers.Layer`, and `layers` is introduced in `tf.keras.Model`. As a result, `tf.keras.Model` has both the `layers` and `._layers` attributes, whereas `tf.keras.layers.Layer` only has the `._layers` attribute.
`tf.keras.layers.Layer` has a `get_weights` method, which returns a list of numpy values of all encapsulated variable weights. Inside a `tf.keras.layers.Layer`, there can be `tf.Variable`s as well as other nested `tf.keras.layers.Layer`s.
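A small sketch of this nesting and of `get_weights` (shapes and names are illustrative):

```python
import tensorflow as tf

class Outer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(1.0)          # a raw tf.Variable inside the layer
        self.inner = tf.keras.layers.Dense(4)  # a nested Layer inside the layer

    def call(self, inputs):
        return self.scale * self.inner(inputs)

layer = Outer()
layer(tf.zeros([1, 3]))          # build the nested Dense layer's weights
print(len(layer.get_weights()))  # 3 numpy arrays: scale, Dense kernel, Dense bias
```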
# Model customization
```python
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data  # a sample_weight can also be unpacked here
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`);
            # additional losses are added via `regularization_losses`
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current values
        return {m.name: m.result() for m in self.metrics}

    def call(self, inputs, *args, training=False, mask=None, **kwargs):
        # *args and **kwargs are not recommended: they weaken Keras's argument hygiene checks
        # more at tf 2.1: https://github.com/tensorflow/tensorflow/blob/e5bf8de410005de06a7ff5393fafdf832ef1d4ad/tensorflow/python/keras/engine/base_layer.py#L628
        # more at tf 2.3: https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/engine/base_layer.py#L875
        pass
```
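A minimal usage sketch, continuing from the block above; `TinyModel`, the data, and the shapes are illustrative, and `call` is overridden only so the forward pass returns real predictions:

```python
import numpy as np

class TinyModel(CustomModel):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)

    def call(self, inputs, training=False, mask=None):
        return self.dense(inputs)

model = TinyModel()
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
x = np.random.random((64, 8)).astype("float32")
y = np.random.random((64, 1)).astype("float32")
model.fit(x, y, epochs=2)  # fit() now routes every batch through the custom train_step
```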
"""
Avoid declaration of tf.Variable inside model call. Model training is fine but export leads to error. Reason:
When a tf.Variable is created in a tf.function, it will be lifted to the default graph context for initialization.
However, if the variable has external dependency such as a placeholder input, the variable will fail to be lifted to
the default graph context. When one defines a sparse operation in their model, the gradient aggregation variable will fail to be initialized. More info: https://github.com/horovod/horovod/pull/3499
"""
See customized optimizer here
Valid Layers
Only layers that appear in both `__init__`/`build` and `call` are effective, i.e., their weights are built and saved in the checkpoint.
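A small sketch of this point (names and shapes are illustrative):

```python
import tensorflow as tf

class Demo(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.used = tf.keras.layers.Dense(4)    # defined and called: weights get built and checkpointed
        self.unused = tf.keras.layers.Dense(4)  # defined but never called: no weights are ever built

    def call(self, inputs):
        # A layer created here instead would be re-created on every call and not tracked,
        # so its weights would not end up in a checkpoint either.
        return self.used(inputs)

m = Demo()
m(tf.zeros([1, 3]))
print([w.name for w in m.weights])  # only the `used` Dense layer's kernel and bias appear
```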
# Callbacks
Using callbacks during `model.fit` has some restrictions.
ProgbarLogger callback:
- tf2.1: ProgbarLogger is the last callback and can only be muted, not customized, when using `model.fit`
- tf2.2: ProgbarLogger can be customized via `CallbackList` when using `model.fit`
- ProgbarLogger is reset every epoch, so in-batch logs show the same display
- ProgbarLogger averages the loss, but not the other (stateful) metrics
TensorBoard callback (a usage sketch follows the list):
- The TensorBoard callback always logs per epoch, even if `update_freq` is set to an integer
- If a metric name starts with `val_`, the TensorBoard callback logs the value with the validation writer (logs go to the validation folder)
- When logging per batch, the TensorBoard callback has no validation writer
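A small usage sketch of the TensorBoard callback; the log directory, `update_freq` value, model, and data are illustrative:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer="adam", loss="mse")

x = np.random.random((64, 8)).astype("float32")
y = np.random.random((64, 1)).astype("float32")

tb = tf.keras.callbacks.TensorBoard(log_dir="./logs", update_freq=10)
model.fit(x, y, epochs=2, validation_data=(x, y), callbacks=[tb])
# Training scalars go under ./logs/train; `val_`-prefixed metrics are written by the
# validation writer under ./logs/validation, and only at epoch boundaries.
```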
# Precision
In the `tf.keras.layers.Layer` constructor there is a `dtype` argument that casts the `inputs` of `call(self, inputs, **kwargs)` to `dtype` if the inputs have the same base type as `dtype` but a different precision.
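A small sketch of this cast; the layer is illustrative:

```python
import tensorflow as tf

class EchoLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        return inputs

layer = EchoLayer(dtype="float64")
out = layer(tf.constant([1.0, 2.0], dtype=tf.float32))
print(out.dtype)  # float64: same base type (floating point) but different precision, so the input is cast
```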
# Shape
If dealing with unknown batch size while building the model graph, use tf.shape(variable_tensor)
instead of variable_tensor.shape
can solve the cannot convert a partially known tensorshape to a tensor
exception. It is because the former gives a tensor whereas the later returns TensorShape
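A small sketch of the difference; the function and shapes are illustrative:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 8], dtype=tf.float32)])
def pad_batch(x):
    dynamic_batch = tf.shape(x)[0]   # a scalar Tensor, known only at run time
    # static_batch = x.shape[0]      # None here; using it in tensor ops raises the
    #                                # "Cannot convert a partially known TensorShape to a Tensor" error
    return tf.fill([dynamic_batch, 1], 0.0)

print(pad_batch(tf.zeros([3, 8])).shape)  # (3, 1)
```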