# Keras

The functional API builds model components line by line and then calls Model(inputs, outputs) to get an encapsulated model. After building a model with the functional API, the graph and weights are already initialized, so a checkpoint can be loaded and checked immediately.
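A minimal functional-API sketch (shapes and the checkpoint object are illustrative) showing that the graph and weights exist as soon as Model(inputs, outputs) returns:

```python
import tensorflow as tf

# Functional construction: graph and weights are created eagerly.
inputs = tf.keras.Input(shape=(16,))
outputs = tf.keras.layers.Dense(4)(inputs)
model = tf.keras.Model(inputs, outputs)

model.summary()                       # works immediately, no call needed
ckpt = tf.train.Checkpoint(model=model)
# ckpt.restore("path/to/ckpt").assert_consumed()  # also usable right after construction
```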

Model subclassing encapsulates the model inherently, but weights and graphs are lazily initialized, so model.summary() or checkpoint.restore(ckpt).assert_consumed() only works after the model has been called.
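A minimal subclassing sketch (shapes are illustrative) of that lazy initialization: summary() fails before the first call and works after it:

```python
import tensorflow as tf

class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, inputs):
        return self.dense(inputs)

model = MLP()
# model.summary()            # would raise: the model has not yet been built
_ = model(tf.zeros((1, 16)))  # first call builds the weights and the graph
model.summary()               # now works
```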

tf.keras.layers.Layer inherits from tf.Module, and tf.keras.Model inherits from tf.keras.layers.Layer. The private ._layers attribute is introduced in tf.keras.layers.Layer, while the public layers property is introduced in tf.keras.Model. As a result, tf.keras.Model has both the layers property and the ._layers attribute, whereas tf.keras.layers.Layer only has ._layers.
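A quick check of the public side of this hierarchy (._layers is private and version-dependent, so only the public pieces are shown):

```python
import tensorflow as tf

# The inheritance chain can be verified directly.
assert issubclass(tf.keras.layers.Layer, tf.Module)
assert issubclass(tf.keras.Model, tf.keras.layers.Layer)

# `layers` is a public property of tf.keras.Model.
model = tf.keras.Sequential([tf.keras.layers.Dense(4), tf.keras.layers.Dense(2)])
print(model.layers)   # [<Dense ...>, <Dense ...>]
```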

tf.keras.layers.Layer has a get_weights method that returns a list of NumPy values for all encapsulated weight variables. Inside a tf.keras.layers.Layer there can be tf.Variable objects as well as other tf.keras.layers.Layer instances.
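A sketch of a custom layer that owns both a raw tf.Variable and a nested layer; once built, get_weights() returns NumPy values for all of them (the layer name and shapes are made up for illustration):

```python
import tensorflow as tf

class Scaled(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)           # plain variable, tracked by the layer
        self.dense = tf.keras.layers.Dense(3)   # nested layer, also tracked

    def call(self, inputs):
        return self.scale * self.dense(inputs)

layer = Scaled()
_ = layer(tf.zeros((1, 5)))                      # build so the Dense kernel/bias exist
print([w.shape for w in layer.get_weights()])    # e.g. [(), (5, 3), (3,)]
```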

Custom training logic (from TF 2.2):

See resource 1

See resource 2

# Model customization

```python
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data  # a sample weight can also be unpacked here

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)  # additional losses added

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metrics that track the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current values
        return {m.name: m.result() for m in self.metrics}

    def call(self, inputs, *args, training=False, mask=None, **kwargs):
        # *args and **kwargs are not recommended: they weaken Keras's argument hygiene checks
        # more at tf 2.1: https://github.com/tensorflow/tensorflow/blob/e5bf8de410005de06a7ff5393fafdf832ef1d4ad/tensorflow/python/keras/engine/base_layer.py#L628
        # more at tf 2.3: https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/engine/base_layer.py#L875
        #
        # Avoid creating a tf.Variable inside `call`. Model training is fine, but export
        # leads to an error. Reason: when a tf.Variable is created in a tf.function, it is
        # lifted to the default graph context for initialization. However, if the variable
        # has an external dependency such as a placeholder input, it fails to be lifted to
        # the default graph context. When a sparse operation is defined in the model, the
        # gradient aggregation variable then fails to be initialized.
        # More info: https://github.com/horovod/horovod/pull/3499
        pass
```

See the customized optimizer here

Valid Layers

Only layers that appear in both __init__ / build and call are effective (i.e., saved in the checkpoint)
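A sketch of that caveat, assuming the usual variable-tracking behavior: a layer assigned in __init__ but never used in call is never built, so it contributes no variables (and hence nothing to the checkpoint):

```python
import tensorflow as tf

class Partial(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.used = tf.keras.layers.Dense(4)
        self.unused = tf.keras.layers.Dense(4)   # never referenced in call, so never built

    def call(self, inputs):
        return self.used(inputs)

model = Partial()
_ = model(tf.zeros((1, 8)))
print([v.name for v in model.variables])   # only the weights of `used` appear
```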

# Callbacks

Using callbacks during model.fit comes with some restrictions.

ProgbarLogger Callback:

  • TF 2.1: ProgbarLogger is the last callback and can only be muted, not customized, when using model.fit
  • TF 2.2: ProgbarLogger can be customized via CallbackList when using model.fit
  • ProgbarLogger is reset every epoch, so in-batch logs share the same display
  • ProgbarLogger averages the loss, but not the other metrics (which are stateful)

Tensorboard Callback:

  • The TensorBoard callback always logs per epoch, even if update_freq is set to an integer
  • If a metric name starts with val_, the TensorBoard callback logs the value using the validation writer (logs go in the validation folder)
  • Within a batch, the TensorBoard callback has no validation writer
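A hedged sketch of the TensorBoard callback configuration discussed above; the log directory and frequency are placeholders:

```python
import tensorflow as tf

# update_freq=100 requests a summary every 100 batches, but (as noted above)
# per-epoch logging still happens, and in-batch logs have no validation writer.
tb_cb = tf.keras.callbacks.TensorBoard(log_dir="logs/run1", update_freq=100)
# model.fit(x, y, validation_data=(x_val, y_val), callbacks=[tb_cb])
```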

# Precision

The tf.keras.layers.Layer constructor has a dtype argument that casts the inputs of call(self, inputs, **kwargs) to dtype, provided the inputs have the same base type as dtype but a different precision.
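A small sketch of that casting behavior: a float32 layer receiving float64 inputs casts them down, because only the precision differs:

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(2, dtype="float32")
x64 = tf.zeros((1, 3), dtype=tf.float64)   # same base type (float), higher precision
y = layer(x64)
print(y.dtype)   # float32
```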

# Shape

When dealing with an unknown batch size while building the model graph, using tf.shape(variable_tensor) instead of variable_tensor.shape solves the "cannot convert a partially known TensorShape to a Tensor" exception, because the former returns a dynamic Tensor whereas the latter returns a static (possibly partially known) TensorShape.
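A sketch of the difference with an unknown batch dimension: tf.shape returns a dynamic tensor usable for reshaping, while .shape yields a partially known TensorShape containing None:

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(4, 4))             # batch dimension is unknown (None)
batch = tf.shape(inputs)[0]                        # dynamic batch size as a tensor
flat = tf.reshape(inputs, (batch, 16))             # works
# flat = tf.reshape(inputs, (inputs.shape[0], 16)) # fails: inputs.shape[0] is None
outputs = tf.keras.layers.Dense(1)(flat)
model = tf.keras.Model(inputs, outputs)
```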
