# Backpropagation
# Clarification
# Routes of gradient
For each variable $w$, backpropagation combines (sums) the gradient contributions of all routes from the output back to $w$:

$$\frac{\partial L}{\partial w} = \sum_{\text{routes } r} \frac{\partial L}{\partial y_r}\,\frac{\partial y_r}{\partial w}$$

For example, if the gradient has three routes $y_0 \Rightarrow x_0$, $y_1 \Rightarrow x_0$, and $y_2 \Rightarrow x_0$, then

$$\frac{\partial L}{\partial x_0} = \frac{\partial L}{\partial y_0}\frac{\partial y_0}{\partial x_0} + \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial x_0} + \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial x_0}$$
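A minimal scalar sketch of this summation (the values and functions below are illustrative, not from the original example):

```python
import tensorflow as tf

# x0 reaches the final value L along three routes (y0, y1, y2);
# its gradient is the sum of the three routes' contributions.
x0 = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x0)
    y0 = 3.0 * x0        # dy0/dx0 = 3
    y1 = x0 * x0         # dy1/dx0 = 2 * x0 = 4
    y2 = tf.sin(x0)      # dy2/dx0 = cos(x0)
    L = y0 + y1 + y2

# dL/dx0 = 3 + 2*x0 + cos(x0) ≈ 6.58: the routes' gradients are summed.
print(tape.gradient(L, x0))
```

The larger example below exercises the same route-by-route behaviour inside an attention-style computation with one query, three keys, and a softmax.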
```python
import tensorflow as tf

# Example tensors (shapes are illustrative): one query row and three key rows.
x1 = tf.random.normal((1, 4))
x2 = tf.random.normal((1, 4))
x3 = tf.random.normal((1, 4))
wq = tf.random.normal((4, 4))
wk = tf.random.normal((4, 4))
wk2 = tf.random.normal((4, 4))
wk3 = tf.random.normal((4, 4))

with tf.GradientTape(persistent=True) as tape:
    tape.watch(wk)
    tape.watch(wk2)
    tape.watch(wk3)
    q = tf.linalg.matmul(x1, wq)
    k = tf.linalg.matmul(x1, wk)
    k2 = tf.linalg.matmul(x2, wk2)
    k3 = tf.linalg.matmul(x3, wk3)
    s = tf.linalg.matmul(q, tf.transpose(k))     # intermediate tensors computed inside
    s2 = tf.linalg.matmul(q, tf.transpose(k2))   # the tape are tracked automatically,
    s3 = tf.linalg.matmul(q, tf.transpose(k3))   # so they need no explicit tape.watch
    S = tf.concat([s, s2, s3 - 10000], axis=1)   # push the third logit far below the others
    y = tf.nn.softmax(S)[0]
    y0 = tf.gather(y, 0)
    y1 = tf.gather(y, 1)
    y2 = tf.gather(y, 2)

# The sum of the softmax outputs is the constant 1, so its gradient w.r.t. s is 0.
print(tape.gradient(y, s))     # => 0
# A single output does receive gradient from s.
print(tape.gradient(y0, s))    # => some value
# Because of the -10000 shift, y2 saturates to 0, so all routes to s3 carry 0 gradient.
print(tape.gradient(y0, s3))   # => 0
print(tape.gradient(y1, s3))   # => 0
print(tape.gradient(y2, s3))   # => 0
print(tape.gradient(y0, wk3))  # => 0, gradients lower in the chain are also 0
```
# Some observations
- For $y = xw$, $\frac{\partial y}{\partial w} = x$, so if the input $x$ is 0, the gradient flowing to $w$ is 0 (a minimal check follows this list).
- Considering the chain $\frac{\partial L}{\partial w} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial w}$, if one factor is 0, then $\frac{\partial L}{\partial w}$ is also 0; that is why the zero gradient at $s_3$ above makes the gradient at $w_{k3}$ zero as well.
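A minimal check of the first observation (the scalar values are illustrative):

```python
import tensorflow as tf

# With a zero input, dy/dw = x = 0, and everything further down the chain gets 0 too.
x = tf.constant(0.0)
w = tf.Variable(5.0)
with tf.GradientTape() as tape:
    y = x * w
    L = tf.square(y)        # dL/dw = dL/dy * dy/dw = 2*y * x = 0

print(tape.gradient(L, w))  # tf.Tensor(0.0, shape=(), dtype=float32)
```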
# Scalar level
# Multi-label classification
Logistic (sigmoid) activation, binary cross-entropy loss.
Forward
- the weighted input sum at hidden unit j: $z_j = \sum_k w_{jk} x_k$
- logistic/sigmoid activation at unit j: $h_j = \sigma(z_j) = \frac{1}{1 + e^{-z_j}}$
- the weighted input sum at output unit i: $z_i = \sum_j w_{ij} h_j$
- logistic/sigmoid activation at output i: $y_i = \sigma(z_i)$
- binary cross-entropy error: $E = -\sum_i \big[\, t_i \ln y_i + (1 - t_i)\ln(1 - y_i) \,\big]$
Backward
Differentiating the error through the output sigmoid gives:

$$\frac{\partial E}{\partial z_i} = \frac{\partial E}{\partial y_i}\,\frac{\partial y_i}{\partial z_i} = \frac{y_i - t_i}{y_i(1 - y_i)} \cdot y_i(1 - y_i) = y_i - t_i$$

then the weight gradients follow from the chain rule:

$$\frac{\partial E}{\partial w_{ij}} = (y_i - t_i)\, h_j, \qquad \frac{\partial E}{\partial w_{jk}} = \Big(\sum_i (y_i - t_i)\, w_{ij}\Big)\, h_j (1 - h_j)\, x_k$$
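A quick numerical check (a sketch; the values of `z` and `t` are arbitrary) that the output-layer gradient really reduces to $y_i - t_i$:

```python
import tensorflow as tf

# Sigmoid outputs with binary cross-entropy: dE/dz should equal y - t.
z = tf.constant([0.3, -1.2, 2.0])   # output pre-activations
t = tf.constant([1.0, 0.0, 1.0])    # binary targets

with tf.GradientTape() as tape:
    tape.watch(z)
    y = tf.sigmoid(z)
    E = -tf.reduce_sum(t * tf.math.log(y) + (1 - t) * tf.math.log(1 - y))

print(tape.gradient(E, z))          # matches y - t
print(tf.sigmoid(z) - t)
```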
# Multi-class classification
Logistic (sigmoid) activation at the hidden layer, but softmax activation at the output layer, cross-entropy loss.
Forward
- the weighted input sum at hidden unit j: $z_j = \sum_k w_{jk} x_k$
- logistic/sigmoid activation at unit j: $h_j = \sigma(z_j)$
- the weighted input sum at output unit i: $z_i = \sum_j w_{ij} h_j$
- softmax activation at output i: $y_i = \frac{e^{z_i}}{\sum_{i'} e^{z_{i'}}}$
- cross-entropy error: $E = -\sum_i t_i \ln y_i$
Backward
Combining $\frac{\partial y_i}{\partial z_{i'}} = y_i(\delta_{ii'} - y_{i'})$ with the error derivative gives the same output-layer form as above (for one-hot targets $t$):

$$\frac{\partial E}{\partial z_i} = y_i - t_i$$

so the weight gradients take the same shape as in the multi-label case.
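The same kind of check for the softmax output layer (again, `z` and the one-hot `t` are arbitrary):

```python
import tensorflow as tf

# Softmax outputs with cross-entropy: dE/dz should equal y - t for a one-hot t.
z = tf.constant([0.3, -1.2, 2.0])
t = tf.constant([0.0, 0.0, 1.0])    # one-hot target

with tf.GradientTape() as tape:
    tape.watch(z)
    y = tf.nn.softmax(z)
    E = -tf.reduce_sum(t * tf.math.log(y))

print(tape.gradient(E, z))          # matches y - t
print(tf.nn.softmax(z) - t)
```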
# Control
# Update layer's trainable
Reference: the official Keras documentation.
Modern Keras provides the following facilities to view and manipulate the trainable state:

```python
# Print the current trainable map:
print(model._get_trainable_state())

# Set every layer to be non-trainable.
# Take care to exclude nested model wrappers: setting model_wrapper.trainable = False
# also sets all of its sub-layers to False.
for layer, trainable in model._get_trainable_state().items():
    layer.trainable = False

# Don't forget to re-compile the model afterwards:
model.compile(...)
```
Important notes about the `BatchNormalization` layer

Many image models contain `BatchNormalization` layers. That layer is a special case on every imaginable count. Here are a few things to keep in mind.

- `BatchNormalization` contains 2 non-trainable weights that get updated during training. These are the variables tracking the mean and variance of the inputs.
- When you set `bn_layer.trainable = False`, the `BatchNormalization` layer will run in inference mode and will not update its mean & variance statistics. This is not the case for other layers in general, as weight trainability and inference/training modes are two orthogonal concepts. But the two are tied in the case of the `BatchNormalization` layer.
- When you unfreeze a model that contains `BatchNormalization` layers in order to do fine-tuning, you should keep the `BatchNormalization` layers in inference mode by passing `training=False` when calling the base model. Otherwise the updates applied to the non-trainable weights will suddenly destroy what the model has learned.
```python
from tensorflow import keras

# Create the base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)

# Freeze the base model
base_model.trainable = False

# Create a new model on top; training=False keeps BatchNormalization in inference mode.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
```
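Once the new top has been trained, a later fine-tuning step could look like the sketch below (the optimizer, loss, and learning rate are assumptions, not part of the notes above). The base model is unfrozen, but because it is still called with `training=False`, its `BatchNormalization` statistics stay fixed:

```python
# Unfreeze the base model for fine-tuning; it is still called with training=False
# above, so the BatchNormalization statistics are not updated.
base_model.trainable = True

model.compile(optimizer=keras.optimizers.Adam(1e-5),   # very low learning rate (assumed)
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# model.fit(new_dataset, epochs=10)                    # hypothetical dataset
```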
# Exclude variable by name
Sometimes variables are pulled into a model implicitly; you can filter them out of the gradient update by name:

```python
# Keep every trainable variable except the one(s) to exclude by name.
training_vars = model.trainable_variables
training_vars = [v for v in training_vars if v.name != 'some_name']
...
gradients = tape.gradient(loss, training_vars)
model.optimizer.apply_gradients(zip(gradients, training_vars))
```
# Exclude part of variable by tf.stop_gradient
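A minimal sketch of the idea, assuming you want gradients for only part of a single variable (the variable, mask, and loss below are illustrative): combine the trainable slice with a `tf.stop_gradient` copy of the frozen slice, so gradients only reach the unmasked entries.

```python
import tensorflow as tf

# A 4-element weight vector; we only want gradients for the first two entries.
w = tf.Variable([1.0, 2.0, 3.0, 4.0])
mask = tf.constant([1.0, 1.0, 0.0, 0.0])   # 1 = trainable part, 0 = frozen part

with tf.GradientTape() as tape:
    # The second term is blocked by tf.stop_gradient, so only the masked-in
    # entries of w receive gradient.
    w_eff = w * mask + tf.stop_gradient(w * (1.0 - mask))
    loss = tf.reduce_sum(tf.square(w_eff))

print(tape.gradient(loss, w))   # zero gradient for the masked-out entries
```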
Reference:
- https://stackoverflow.com/a/43368518/6845273
- https://www.tensorflow.org/api_docs/python/tf/stop_gradient