In the following we provide a simple TensorFlow example that approximates a function with a neural network. We start with a very basic setup and then, step by step, introduce several techniques that speed up training and improve the result.
# some imports that we need for our example
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import init_ops
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.python.training.moving_averages import assign_moving_average
from tensorflow.contrib.layers.python.layers import utils
import matplotlib.pyplot as plt
In order to compute operations in TensorFlow, a so-called computation graph is first constructed by calling TensorFlow methods. The computation graph is a collection of concatenated mathematical operations such as matrix multiplication, the application of activation functions, and (automatic) differentiation. Once the graph is fixed, we can perform forward and backward computations along its edges.
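As a minimal illustration of this construct-then-run principle, consider the following toy snippet (independent of the example below, using the imports above):
tf.reset_default_graph()
a = tf.constant(3.0)
b = tf.constant(4.0)
c = a * b + 2.0  # nothing is computed yet, we only add nodes to the graph
with tf.Session() as sess:  # the session executes the graph
    print(sess.run(c))  # 14.0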
def neural_net(x, name, num_neurons, activation_fn=tf.nn.relu, reuse=None, dtype=tf.float32):
    def fc_layer(_x, out_size, activation):
        shape = _x.get_shape().as_list()  # shape of the input to the layer
        w = tf.get_variable(
            name='weights',
            shape=[shape[-1], out_size],
            dtype=dtype,
            initializer=initializers.xavier_initializer())  # weight matrix
        b = tf.get_variable(
            name='bias',
            shape=[1, out_size],
            dtype=dtype,
            initializer=tf.zeros_initializer())  # bias vector
        return activation(tf.matmul(_x, w) + b)  # in this case: relu(Wx + b)

    with tf.variable_scope(name, reuse=reuse):  # we pass along the output of each layer as the input of the next
        for i in range(len(num_neurons)):
            with tf.variable_scope('layer_{0}'.format(i)):
                x = fc_layer(x, num_neurons[i],
                             activation_fn if i < len(num_neurons) - 1 else tf.identity)
        return x
The function neural_net takes a tensor x and returns f(x), where f is a neural network of the form
$$f = A_k \circ \mathrm{relu} \circ A_{k-1} \circ \cdots \circ \mathrm{relu} \circ A_1$$
composed of affine transformations $A_i \colon v \mapsto W_i v + b_i$ and relu activations, applied to x.
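As a quick sanity check (a sketch relying on the definitions above; the scope name demo_net is arbitrary), we can build such a network and list the trainable variables it creates:
tf.reset_default_graph()
x = tf.placeholder(tf.float32, shape=(None, 1))
y = neural_net(x, 'demo_net', [64, 64, 1])
for v in tf.trainable_variables():
    print(v.name, v.get_shape().as_list())
# demo_net/layer_0/weights:0 [1, 64]
# demo_net/layer_0/bias:0 [1, 64]
# ... up to demo_net/layer_2/weights:0 [64, 1] and demo_net/layer_2/bias:0 [1, 1]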
def fun(x):
    return tf.sin(x)
We will attempt to learn the non-linear function $x \mapsto \sin(x)$ by means of labeled samples $(x_1,y_1), (x_2,y_2), \ldots$ with inputs $x_i$ and labels $y_i=\sin(x_i)$.
batch_size = 256  # in each optimization step we move along the negative gradient averaged over batch_size samples
train_steps = 5000  # and perform 5000 optimization steps in total
x_min, x_max = 0., 10.  # we draw samples from the interval [x_min, x_max]
tf.reset_default_graph()  # remove all nodes and reset the computation graph to empty
x_input = tf.placeholder(tf.float32, shape=(None, 1))  # placeholder tensor whose value must be specified ('fed') at runtime
y_samples = fun(x_input)  # (operation to) create the labels y at runtime
y_apx = neural_net(x_input, 'neural_net', [64, 64, 1])  # ... and the prediction y' of the neural net
loss = tf.reduce_mean((y_samples - y_apx) ** 2)  # mean squared error between the vectors y and y'
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss)  # create the optimization operation
Calling ('running') the train_op operation during a TF session computes all necessary elementary operations in the graph tracing back to x_input and updates all the trainable variables, i.e. those created by the get_variable function, according to the Adam optimization algorithm.
We will now instantiate a TF Session object, which enables running operations on the previously constructed graph. Within the session we perform the training, i.e., consecutive iterations of the minimization algorithm by means of train_op.
with tf.Session() as sess:
    # the initializer needs to be run in order to obtain values for each non-constant initialized variable
    sess.run(tf.global_variables_initializer())
    # perform training
    for step in range(train_steps):
        samples = np.random.uniform(low=x_min, high=x_max, size=(batch_size, 1))  # samples from the numpy module
        _, l = sess.run([train_op, loss],  # run all operations in the list at once
                        feed_dict={x_input: samples})  # feed the placeholder x_input at runtime
        if (step + 1) % 100 == 0:
            print('step: {0}, loss: {1}'.format(step + 1, l))
    # create an equidistant test grid X in [x_min, x_max] with labels {f(x): x in X}
    samples = np.linspace(x_min, x_max, 500)
    samples = np.reshape(samples, (500, 1))
    y_exact, y_test = sess.run([y_samples, y_apx], feed_dict={x_input: samples})
Let's take a look at the mapping the neural network has learned from the training samples.
plt.plot(samples, y_exact)
plt.plot(samples, y_test)
plt.show()
In the next section we will replace the tf.placeholder(), which has to be fed at runtime, with a tf.random_uniform() tensor. The advantage is that the distribution of the tensor is already known when the graph is constructed, so the random numbers can be generated directly on the GPU at runtime, resulting in fewer input/output operations and thus faster training.
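Note that a tf.random_uniform tensor draws fresh samples on every evaluation, so each training step below automatically sees a new batch. A tiny sketch (independent of the model) demonstrating this behavior:
tf.reset_default_graph()
r = tf.random_uniform(minval=0., maxval=1., shape=(3,))
with tf.Session() as sess:
    print(sess.run(r))  # three random numbers ...
    print(sess.run(r))  # ... and three different ones on the second run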
batch_size = 256
train_steps = 5000
x_min, x_max = 0., 10.
tf.reset_default_graph()
x_input = tf.random_uniform(minval=x_min, maxval=x_max, shape=(batch_size, 1))
y_samples = fun(x_input)
y_apx = neural_net(x_input, 'neural_net', [64, 64, 1])
loss = tf.reduce_mean((y_samples - y_apx) ** 2)
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss)
x_eval = tf.linspace(x_min, x_max, 500)
x_eval = tf.reshape(x_eval, (500, 1))
y_fun = fun(x_eval)
y_eval = neural_net(x_eval, 'neural_net', [64, 64, 1], reuse=tf.AUTO_REUSE)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(train_steps):
        _, l = sess.run([train_op, loss])
        if (step + 1) % 100 == 0:
            print('step: {0}, loss: {1}'.format(step + 1, l))
    x_test, y_exact, y_test = sess.run([x_eval, y_fun, y_eval])
plt.plot(x_test, y_exact)
plt.plot(x_test, y_test)
plt.show()
When working with the relu activation function and randomly initialized weight matrices, it is very likely that a large fraction of the hidden nodes yields a flat zero gradient. To resolve this issue and accelerate training, we apply a transformation to each batch of pre-activations that concentrates the values to be activated around zero, which is crucial for effective gradient computation. This effect becomes much more severe if we change the interval from which we sample to large numbers, for example [100, 110].
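To make this concrete, here is a small sketch (not part of the model itself) that estimates the fraction of relu units in a single layer that are zero on an entire batch drawn from [100, 110]; with zero bias and a single positive input, every unit with a negative weight is inactive on all samples:
tf.reset_default_graph()
x = tf.random_uniform(minval=100., maxval=110., shape=(1024, 1))
w = tf.get_variable('w', shape=[1, 64], initializer=initializers.xavier_initializer())
b = tf.get_variable('b', shape=[1, 64], initializer=tf.zeros_initializer())
h = tf.nn.relu(tf.matmul(x, w) + b)
# a unit is 'dead' on this batch if its activation is zero for every sample
dead = tf.reduce_mean(tf.cast(tf.equal(tf.reduce_max(h, axis=0), 0.), tf.float32))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('fraction of dead units: {0}'.format(sess.run(dead)))  # roughly one half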
def neural_net_bn(x, name, num_neurons, is_training,
                  activation_fn=tf.nn.relu, reuse=None,
                  decay=0.9, dtype=tf.float32):
    # we define the batch-normalization layer in an elementary fashion
    def batch_normalization(_x):
        shape = _x.get_shape().as_list()
        # trainable shift (beta) and scale (gamma) parameters
        beta = tf.get_variable(name='beta', shape=[shape[-1]], dtype=dtype,
                               initializer=init_ops.zeros_initializer())
        gamma = tf.get_variable(name='gamma', shape=[shape[-1]], dtype=dtype,
                                initializer=init_ops.ones_initializer())
        # non-trainable moving averages of the batch statistics, used at test time
        mv_mean = tf.get_variable('mv_mean', [shape[-1]], dtype=dtype,
                                  initializer=init_ops.zeros_initializer(),
                                  trainable=False)
        mv_var = tf.get_variable('mv_var', [shape[-1]], dtype=dtype,
                                 initializer=init_ops.zeros_initializer(),
                                 trainable=False)
        mean, variance = tf.nn.moments(_x, [0], name='moments')
        # register the moving-average updates so they can be run alongside the training step
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                             assign_moving_average(mv_mean, mean, decay,
                                                   zero_debias=True))
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                             assign_moving_average(mv_var, variance, decay,
                                                   zero_debias=True))
        # use the batch statistics during training and the moving averages at test time
        mean, variance = utils.smart_cond(is_training,
                                          lambda: (mean, variance),
                                          lambda: (mv_mean, mv_var))
        return tf.nn.batch_normalization(_x, mean, variance, beta, gamma, 1e-6)

    def fc_layer(_x, out_size, activation):
        shape = _x.get_shape().as_list()
        w = tf.get_variable(
            name='weights',
            shape=[shape[-1], out_size],
            dtype=dtype,
            initializer=initializers.xavier_initializer())
        return activation(batch_normalization(tf.matmul(_x, w)))

    with tf.variable_scope(name, reuse=reuse):
        x = batch_normalization(x)
        for i in range(len(num_neurons)):
            with tf.variable_scope('layer_{0}'.format(i)):
                x = fc_layer(x, num_neurons[i],
                             activation_fn if i < len(num_neurons) - 1 else tf.identity)
        return x
When using batch normalization we need to make sure that the moving averages of the batch mean and variance are updated during training. To this end, we perform the optimization step within a tf.control_dependencies() block over the update operations collected in tf.GraphKeys.UPDATE_OPS.
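The semantics of tf.control_dependencies can be illustrated with a small sketch (hypothetical variables, independent of the model): every operation created inside the block is only executed after the listed operations have been executed.
tf.reset_default_graph()
v = tf.Variable(0.)
increment = tf.assign_add(v, 1.)
with tf.control_dependencies([increment]):
    read = tf.identity(v)  # running 'read' forces 'increment' to run first
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(read))  # 1.0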
batch_size = 8192
train_steps = 2000
x_min, x_max = 0., 10.
tf.reset_default_graph()
x_input = tf.random_uniform(minval=x_min, maxval=x_max, shape=(batch_size, 1))
y_samples = fun(x_input)
y_apx = neural_net_bn(x_input, 'neural_net', [64, 64, 1], is_training=True)
loss = tf.reduce_mean((y_samples - y_apx) ** 2)
optimizer = tf.train.AdamOptimizer()
# ensure the moving averages are updated with every training step
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = optimizer.minimize(loss)
x_eval = tf.linspace(x_min, x_max, 500)
x_eval = tf.reshape(x_eval, (500, 1))
y_fun = fun(x_eval)
y_eval = neural_net_bn(x_eval, 'neural_net', [64, 64, 1], reuse=tf.AUTO_REUSE, is_training=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(train_steps):
        _, l = sess.run([train_op, loss])
        if (step + 1) % 100 == 0:
            print('step: {0}, loss: {1}'.format(step + 1, l))
    x_test, y_exact, y_test = sess.run([x_eval, y_fun, y_eval])
plt.plot(x_test, y_exact)
plt.plot(x_test, y_test)
plt.show()
Finally, we provide an explicit training schedule for the learning rate used by the Adam optimizer. This further speeds up training, as can be observed below.
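The schedule will be implemented with tf.train.piecewise_constant, which takes one more value than boundaries. A quick sketch (independent of the model) of how it maps the global step to a learning rate:
tf.reset_default_graph()
step = tf.placeholder(tf.int32)
lr = tf.train.piecewise_constant(step, [200, 500], [0.1, 0.01, 0.001])
with tf.Session() as sess:
    for s in [0, 100, 300, 600]:
        print(s, sess.run(lr, feed_dict={step: s}))  # prints 0.1, 0.1, 0.01, 0.001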
batch_size = 8192
train_steps = 1000
x_min, x_max = 0., 10.
# declaration of the training schedule
lr_boundaries = [200, 500]  # switch the learning rate after 200 and after 500 steps
lr_values = [0.1, 0.01, 0.001]  # one more value than boundaries
tf.reset_default_graph()
x_input = tf.random_uniform(minval=x_min, maxval=x_max, shape=(batch_size, 1))
y_samples = fun(x_input)
y_apx = neural_net_bn(x_input, 'neural_net', [64, 64, 1], is_training=True)
loss = tf.reduce_mean((y_samples - y_apx) ** 2)
global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.piecewise_constant(global_step, lr_boundaries, lr_values)
optimizer = tf.train.AdamOptimizer(learning_rate)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = optimizer.minimize(loss, global_step=global_step)
x_eval = tf.linspace(x_min, x_max, 500)
x_eval = tf.reshape(x_eval, (500, 1))
y_fun = fun(x_eval)
y_eval = neural_net_bn(x_eval, 'neural_net', [64, 64, 1], reuse=tf.AUTO_REUSE, is_training=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(train_steps):
        _, l = sess.run([train_op, loss])
        if (step + 1) % 100 == 0:
            print('step: {0}, loss: {1}'.format(step + 1, l))
    x_test, y_exact, y_test = sess.run([x_eval, y_fun, y_eval])
plt.plot(x_test, y_exact)
plt.plot(x_test, y_test)
plt.show()