Function approximation with neural networks

In the following, we provide a simple TensorFlow example that approximates a function with a neural network. We start with a very basic setup and then, step by step, introduce several techniques that speed up training and improve the result.

In [5]:
# some imports that we need for our example
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import init_ops
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.python.training.moving_averages import assign_moving_average
from tensorflow.contrib.layers.python.layers import utils
import matplotlib.pyplot as plt

To compute anything in TensorFlow (version 1.x, as used throughout this notebook), we first construct a so-called computation graph by calling TensorFlow functions. The computation graph is a network of chained mathematical operations, such as matrix multiplications, applications of activation functions, and the gradient operations produced by (automatic) differentiation. Once the graph has been built, we can run forward and backward computations along its edges.

In [6]:
def neural_net(x, name, num_neurons, activation_fn=tf.nn.relu, reuse=None, dtype=tf.float32):
    """Build a fully connected feed-forward network on top of ``x``.

    Layer ``k`` computes ``act(h @ W_k + b_k)``. Every layer except the last one
    uses ``activation_fn``; the last layer is linear (``tf.identity``).

    Args:
        x: input tensor; its last dimension must be statically known because it
            fixes the number of rows of the first weight matrix.
        name: variable scope that holds all weights and biases of the network.
        num_neurons: output width of each layer, in order; the last entry is the
            output dimension of the network.
        activation_fn: nonlinearity applied after every hidden layer.
        reuse: forwarded to ``tf.variable_scope``; pass ``True`` or
            ``tf.AUTO_REUSE`` to share the variables of an already built network.
        dtype: dtype of the created variables.

    Returns:
        The output tensor of the last layer.
    """
    last = len(num_neurons) - 1

    def fc_layer(inputs, width, act):
        # One affine map followed by `act`; variables are created in the current scope.
        fan_in = inputs.get_shape().as_list()[-1]
        weights = tf.get_variable(name='weights', shape=[fan_in, width], dtype=dtype,
                                  initializer=initializers.xavier_initializer())
        bias = tf.get_variable(name='bias', shape=[1, width], dtype=dtype,
                               initializer=tf.zeros_initializer())
        return act(tf.matmul(inputs, weights) + bias)

    with tf.variable_scope(name, reuse=reuse):
        # Each layer consumes the previous layer's output.
        for idx, width in enumerate(num_neurons):
            act = tf.identity if idx == last else activation_fn
            with tf.variable_scope('layer_{0}'.format(idx)):
                x = fc_layer(x, width, act)
        return x

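The function neural_net takes a tensor x and returns $f(x)$, where $f$ is the neural network $$f = A_k \circ \mathrm{relu} \circ A_{k-1} \circ \cdots \circ \mathrm{relu} \circ A_1,$$ and each $A_i(h) = h W_i + b_i$ is an affine map (one fully connected layer, with $k$ = len(num_neurons)).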

In [7]:
def fun(x):
    """Target function to be approximated: the elementwise sine of ``x``.

    ``x`` is a TensorFlow tensor (or any value ``tf.sin`` accepts); the result is
    a graph tensor that is only evaluated when run in a session.
    """
    target = tf.sin(x)
    return target

We will attempt to learn the non-linear function $x \mapsto \sin(x)$ from labeled samples $(x_1,y_1), (x_2,y_2), \ldots$ with inputs $x_i$ and labels $y_i=\sin(x_i)$.

In [8]:
# Training configuration.
batch_size = 256  # each optimization step follows the mean gradient over this many samples
train_steps = 5000  # total number of optimization steps
x_min, x_max = 0.0, 10.0  # training inputs are drawn from the interval [x_min, x_max]

# Start from an empty computation graph.
tf.reset_default_graph()

# Column of inputs whose values must be supplied via `feed_dict` at run time.
x_input = tf.placeholder(dtype=tf.float32, shape=[None, 1])

# Graph nodes for the exact labels and for the network's prediction; neither is
# computed until the graph is run in a session.
y_samples = fun(x_input)
y_apx = neural_net(x_input, name='neural_net', num_neurons=[64, 64, 1])

# Mean squared error between labels and predictions.
loss = tf.reduce_mean((y_samples - y_apx) ** 2)

# Adam with default hyper-parameters; each run of `train_op` is one update step.
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss)

Calling ('running') the train_op operator in a TF session computes every elementary operation in the graph that it depends on, tracing back to x_input. It then updates all trainable variables (here, those created with the get_variable function) according to the Adam optimization algorithm.

We will now instantiate a TF Session object, which lets us run operations on the previously constructed graph. Within the Session we perform the training, i.e., repeated steps of the minimization algorithm, each triggered by running train_op.

In [12]:
with tf.Session() as sess:
    # Every variable needs its initial value before the first training step.
    sess.run(tf.global_variables_initializer())

    # Training loop: draw a fresh uniform batch with numpy and feed it to `x_input`;
    # running [train_op, loss] together performs one update and returns the batch loss.
    for step in range(train_steps):
        batch = np.random.uniform(low=x_min, high=x_max, size=(batch_size, 1))
        _, loss_value = sess.run([train_op, loss], feed_dict={x_input: batch})
        if (step + 1) % 100 == 0:
            print('step: {0}, loss {1}'.format(step, loss_value))

    # Evenly spaced test grid in [x_min, x_max] as a (500, 1) column, together with
    # the exact labels and the network's predictions on it.
    samples = np.linspace(x_min, x_max, 500).reshape((500, 1))
    y_exact, y_test = sess.run([y_samples, y_apx], feed_dict={x_input: samples})
step: 99, loss 0.3944329619407654
step: 199, loss 0.32044145464897156
step: 299, loss 0.26136747002601624
step: 399, loss 0.2090146690607071
step: 499, loss 0.2022627741098404
step: 599, loss 0.1780681610107422
step: 699, loss 0.2091144323348999
step: 799, loss 0.19174236059188843
step: 899, loss 0.182851180434227
step: 999, loss 0.1770906150341034
step: 1099, loss 0.15998543798923492
step: 1199, loss 0.1557508111000061
step: 1299, loss 0.16371136903762817
step: 1399, loss 0.1480572521686554
step: 1499, loss 0.20495465397834778
step: 1599, loss 0.1997828483581543
step: 1699, loss 0.17092067003250122
step: 1799, loss 0.1895742416381836
step: 1899, loss 0.19024419784545898
step: 1999, loss 0.1762785166501999
step: 2099, loss 0.2372676283121109
step: 2199, loss 0.153299018740654
step: 2299, loss 0.1690065562725067
step: 2399, loss 0.2256040722131729
step: 2499, loss 0.17590314149856567
step: 2599, loss 0.14365652203559875
step: 2699, loss 0.18380023539066315
step: 2799, loss 0.1748751848936081
step: 2899, loss 0.18931184709072113
step: 2999, loss 0.1851882040500641
step: 3099, loss 0.1506672501564026
step: 3199, loss 0.13654044270515442
step: 3299, loss 0.10756970196962357
step: 3399, loss 0.08258842676877975
step: 3499, loss 0.03971070423722267
step: 3599, loss 0.019938023760914803
step: 3699, loss 0.009403139352798462
step: 3799, loss 0.0053341565653681755
step: 3899, loss 0.003660729853436351
step: 3999, loss 0.0025727790780365467
step: 4099, loss 0.0026825759559869766
step: 4199, loss 0.002505695214495063
step: 4299, loss 0.0019232844933867455
step: 4399, loss 0.0017673502443358302
step: 4499, loss 0.0020390197169035673
step: 4599, loss 0.0024595085997134447
step: 4699, loss 0.0021815067157149315
step: 4799, loss 0.0018529579974710941
step: 4899, loss 0.0014773695729672909
step: 4999, loss 0.0013659417163580656

Let's take a look at the mapping learned by the neural network from fitting to the training samples.

In [13]:
# Exact target (first curve) and network prediction (second curve) on the test grid.
for curve in (y_exact, y_test):
    plt.plot(samples, curve)
plt.show()

In the next section, we replace the tf.placeholder(), which has to be fed at runtime, with a tf.random_uniform() tensor. The advantage is that sampling becomes part of the graph itself. The random numbers are generated by TensorFlow on the device that executes the graph (e.g. the GPU), so no data has to be copied from Python into the graph at every step. This reduces input/output overhead and speeds up training.

In [14]:
# Same configuration as before; the inputs are now sampled inside the graph.
batch_size = 256
train_steps = 5000
x_min, x_max = 0.0, 10.0

tf.reset_default_graph()

# A fresh uniform batch of shape (batch_size, 1) is drawn every time the graph runs.
x_input = tf.random_uniform(shape=(batch_size, 1), minval=x_min, maxval=x_max)

y_samples = fun(x_input)
y_apx = neural_net(x_input, 'neural_net', [64, 64, 1])
loss = tf.reduce_mean((y_samples - y_apx) ** 2)  # mean squared error

optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss)

# Evaluation graph: `num_eval` evenly spaced points in [x_min, x_max], passed through
# the same network variables (shared via tf.AUTO_REUSE).
num_eval = 500
x_eval = tf.reshape(tf.linspace(x_min, x_max, num_eval), (num_eval, 1))
y_fun = fun(x_eval)
y_eval = neural_net(x_eval, 'neural_net', [64, 64, 1], reuse=tf.AUTO_REUSE)
In [15]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initial values for all variables

    # No feed_dict needed: the training batch is generated inside the graph.
    for step in range(train_steps):
        loss_value = sess.run([train_op, loss])[1]
        if (step + 1) % 100 == 0:
            print('step: {0}, loss {1}'.format(step, loss_value))

    # Fetch the grid, the exact values and the prediction in a single run, then plot
    # the exact curve first and the prediction second.
    x_test, y_exact, y_test = sess.run([x_eval, y_fun, y_eval])
    for curve in (y_exact, y_test):
        plt.plot(x_test, curve)
    plt.show()
step: 99, loss 0.45837071537971497
step: 199, loss 0.37495774030685425
step: 299, loss 0.3075532913208008
step: 399, loss 0.21496528387069702
step: 499, loss 0.1830555945634842
step: 599, loss 0.16028404235839844
step: 699, loss 0.1327740103006363
step: 799, loss 0.14208462834358215
step: 899, loss 0.13480737805366516
step: 999, loss 0.1035740077495575
step: 1099, loss 0.07661621272563934
step: 1199, loss 0.04826673865318298
step: 1299, loss 0.05924380198121071
step: 1399, loss 0.0401323176920414
step: 1499, loss 0.017693253234028816
step: 1599, loss 0.019206056371331215
step: 1699, loss 0.012577931396663189
step: 1799, loss 0.007105778902769089
step: 1899, loss 0.005776992067694664
step: 1999, loss 0.0047250292263925076
step: 2099, loss 0.004326990339905024
step: 2199, loss 0.0033508935011923313
step: 2299, loss 0.005356467794626951
step: 2399, loss 0.0036242124624550343
step: 2499, loss 0.0029315263964235783
step: 2599, loss 0.003069708589464426
step: 2699, loss 0.004951186943799257
step: 2799, loss 0.0034277255181223154
step: 2899, loss 0.003175192978233099
step: 2999, loss 0.0034575401805341244
step: 3099, loss 0.004660377278923988
step: 3199, loss 0.0033249875996261835
step: 3299, loss 0.002721356460824609
step: 3399, loss 0.004368613939732313
step: 3499, loss 0.004228290170431137
step: 3599, loss 0.003654658328741789
step: 3699, loss 0.0025696323718875647
step: 3799, loss 0.003579198382794857
step: 3899, loss 0.003729198593646288
step: 3999, loss 0.003124064765870571
step: 4099, loss 0.002153692301362753
step: 4199, loss 0.0034418883733451366
step: 4299, loss 0.0030124341137707233
step: 4399, loss 0.0038799052126705647
step: 4499, loss 0.0026746010407805443
step: 4599, loss 0.00443761982023716
step: 4699, loss 0.003596237860620022
step: 4799, loss 0.0029312074184417725
step: 4899, loss 0.0034144483506679535
step: 4999, loss 0.0042300717905163765

When working with the relu activation function and randomly initialized weight matrices, a large fraction of the hidden units can easily end up with negative pre-activations for all inputs. Such units output zero and have a flat zero gradient, so they stop learning. To resolve this issue and accelerate training, we apply batch normalization. Each pre-activation is normalized over the batch to zero mean and unit variance, then scaled and shifted by the learned parameters $\gamma$ and $\beta$. As a result, the values entering the relu are concentrated around zero, where its gradient is informative. The effect is even more pronounced if we move the sampling interval to large numbers, for example [100, 110].

In [16]:
def neural_net_bn(x, name, num_neurons, is_training,
                  activation_fn=tf.nn.relu, reuse=None,
                  decay=0.9, dtype=tf.float32):
    """Build a fully connected network with batch normalization before each activation.

    The raw input ``x`` is batch-normalized first; after that, layer ``k`` computes
    ``act(BN(h @ W_k))``. No bias is added because the learned shift ``beta`` of
    the batch normalization plays that role. Every layer except the last uses
    ``activation_fn``; the last layer is linear (``tf.identity``).

    Args:
        x: 2-D input tensor (batch, features); the feature dimension must be
            statically known. Normalization statistics are taken over axis 0.
        name: variable scope that holds all variables of the network.
        num_neurons: output width of each layer, in order.
        is_training: Python bool or boolean tensor. If true, each batch is
            normalized with its own mean/variance and the moving-average update
            ops are added to ``tf.GraphKeys.UPDATE_OPS``; if false, the stored
            moving averages are used. When it is statically ``False``, no update
            ops are registered, so building an evaluation graph never feeds
            evaluation statistics into the moving averages.
        activation_fn: nonlinearity applied after every hidden layer.
        reuse: forwarded to ``tf.variable_scope`` (e.g. ``tf.AUTO_REUSE`` to share
            the variables of a previously built network).
        decay: decay rate of the moving averages of mean and variance.
        dtype: dtype of all created variables.

    Returns:
        The output tensor of the last layer.
    """
    # Static value of `is_training` (None if it is only known at run time). Update
    # ops are registered unless the network can never run in training mode.
    static_training = utils.constant_value(is_training)
    register_updates = static_training is None or bool(static_training)

    def batch_normalization(_x):
        # Normalize `_x` per feature over the batch axis, then apply the learned
        # scale `gamma` and shift `beta`; variables live in the current scope.
        shape = _x.get_shape().as_list()
        beta = tf.get_variable(name='beta', shape=[shape[-1]], dtype=dtype,
                               initializer=init_ops.zeros_initializer())
        gamma = tf.get_variable(name='gamma', shape=[shape[-1]], dtype=dtype,
                                initializer=init_ops.ones_initializer())
        # Moving averages used in evaluation mode; not trained by the optimizer.
        # Both start at zero, the initialization that zero_debias=True assumes.
        mv_mean = tf.get_variable('mv_mean', [shape[-1]], dtype=dtype,
                                  initializer=init_ops.zeros_initializer(),
                                  trainable=False)
        mv_var = tf.get_variable('mv_var', [shape[-1]], dtype=dtype,
                                 initializer=init_ops.zeros_initializer(),
                                 trainable=False)
        mean, variance = tf.nn.moments(_x, [0], name='moments')
        if register_updates:
            # These ops only run if something depends on them, e.g. a training op
            # created under tf.control_dependencies on the UPDATE_OPS collection.
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                                 assign_moving_average(mv_mean, mean, decay,
                                                       zero_debias=True))
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                                 assign_moving_average(mv_var, variance, decay,
                                                       zero_debias=True))
        # Batch statistics in training mode, stored moving averages otherwise.
        mean, variance = utils.smart_cond(is_training,
                                          lambda: (mean, variance),
                                          lambda: (mv_mean, mv_var))
        return tf.nn.batch_normalization(_x, mean, variance, beta, gamma, 1e-6)

    def fc_layer(_x, out_size, activation):
        # Bias-free linear map, then batch normalization, then `activation`.
        shape = _x.get_shape().as_list()
        w = tf.get_variable(
            name='weights',
            shape=[shape[-1], out_size],
            dtype=dtype,
            initializer=initializers.xavier_initializer())
        return activation(batch_normalization(tf.matmul(_x, w)))

    last = len(num_neurons) - 1
    with tf.variable_scope(name, reuse=reuse):
        # Normalize the raw input as well; its variables live directly in `name`.
        x = batch_normalization(x)
        for i, width in enumerate(num_neurons):
            with tf.variable_scope('layer_{0}'.format(i)):
                x = fc_layer(x, width,
                             tf.identity if i == last else activation_fn)
        return x

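When using batch normalization, we also need to keep track of moving averages of the batch mean and variance during training. At evaluation time, these stored statistics replace the statistics of the current batch. The learned scale $\gamma$ and shift $\beta$ are ordinary trainable variables that the optimizer updates. The moving averages, however, are updated by separate operations collected in tf.GraphKeys.UPDATE_OPS. To make sure these run at every training step, we create the training operation within tf.control_dependencies() on those update operations.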

In [74]:
batch_size = 8192
train_steps = 2000
x_min, x_max = 0.0, 10.0

tf.reset_default_graph()

# Training graph: in-graph uniform samples, network normalized with batch statistics.
x_input = tf.random_uniform(shape=(batch_size, 1), minval=x_min, maxval=x_max)
y_samples = fun(x_input)
y_apx = neural_net_bn(x_input, 'neural_net', [64, 64, 1], is_training=True)
loss = tf.reduce_mean((y_samples - y_apx) ** 2)  # mean squared error

optimizer = tf.train.AdamOptimizer()
# The moving averages of the batch mean/variance are not trained by the optimizer;
# their update ops are stored in UPDATE_OPS. Making them control dependencies of the
# training op runs them at every training step. The collection is read here, before
# the evaluation graph below is built.
bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(bn_updates):
    train_op = optimizer.minimize(loss)

# Evaluation graph on an evenly spaced grid; it shares the trained variables and
# normalizes with the stored moving averages.
num_eval = 500
x_eval = tf.reshape(tf.linspace(x_min, x_max, num_eval), (num_eval, 1))
y_fun = fun(x_eval)
y_eval = neural_net_bn(x_eval, 'neural_net', [64, 64, 1],
                       reuse=tf.AUTO_REUSE, is_training=False)
In [75]:
with tf.Session() as sess:
    # Initialize trainable weights and the (non-trainable) moving averages alike.
    sess.run(tf.global_variables_initializer())

    for step in range(train_steps):
        _, batch_loss = sess.run([train_op, loss])
        report = (step + 1) % 100 == 0
        if report:
            print('step: {0}, loss {1}'.format(step, batch_loss))

    # Evaluate on the grid and plot exact values (first) against predictions (second).
    x_test, y_exact, y_test = sess.run([x_eval, y_fun, y_eval])
    plt.plot(x_test, y_exact)
    plt.plot(x_test, y_test)
    plt.show()
step: 99, loss 0.04594836011528969
step: 199, loss 0.013351433910429478
step: 299, loss 0.003041457384824753
step: 399, loss 0.0018526293570175767
step: 499, loss 0.0015206707175821066
step: 599, loss 0.0048472522757947445
step: 699, loss 0.0025380728766322136
step: 799, loss 0.0004803991469088942
step: 899, loss 0.004159579984843731
step: 999, loss 0.003069471102207899
step: 1099, loss 0.001511941198259592
step: 1199, loss 0.0008276765584014356
step: 1299, loss 0.008634578436613083
step: 1399, loss 0.0011270123068243265
step: 1499, loss 0.0015029151691123843
step: 1599, loss 0.002813320141285658
step: 1699, loss 0.003993199672549963
step: 1799, loss 0.001451092422939837
step: 1899, loss 0.00025049239047802985
step: 1999, loss 0.0017990634078159928

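Finally, we provide an explicit schedule for the learning rate used by the Adam optimizer: a large rate at the beginning, which is reduced in two steps. As can be observed below, this speeds up training further. After only 1000 steps, the loss is lower than the loss reached after 2000 steps without the schedule.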

In [76]:
batch_size = 8192
train_steps = 1000
x_min, x_max = 0., 10.

# Piecewise-constant learning-rate schedule for Adam
# (tf.train.piecewise_constant semantics):
#   global_step <= 200        -> 0.1
#   200 < global_step <= 500  -> 0.01
#   global_step > 500         -> 0.001
lr_boundaries = [200, 500]
lr_values = [0.1, 0.01, 0.001]

tf.reset_default_graph()

x_input = tf.random_uniform(minval=x_min, maxval=x_max, shape=(batch_size, 1))

y_samples = fun(x_input)
y_apx = neural_net_bn(x_input, 'neural_net', [64, 64, 1], is_training=True)

loss = tf.reduce_mean((y_samples - y_apx) ** 2)  # mean squared error

# Step counter that drives the schedule; passing it to `minimize` makes every
# training step increment it by one.
global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.piecewise_constant(global_step, lr_boundaries, lr_values)
optimizer = tf.train.AdamOptimizer(learning_rate)
# Run the batch-norm moving-average updates (UPDATE_OPS) with every training step.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = optimizer.minimize(loss, global_step=global_step)

# Evaluation graph on an evenly spaced grid, sharing the trained variables and
# normalizing with the stored moving averages.
x_eval = tf.linspace(x_min, x_max, 500)
x_eval = tf.reshape(x_eval, (500, 1))
y_fun = fun(x_eval)
y_eval = neural_net_bn(x_eval, 'neural_net', [64, 64, 1], reuse=tf.AUTO_REUSE, is_training=False)
In [77]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initial values for all variables

    # Each run of train_op performs one optimization step on a fresh in-graph batch.
    for step in range(train_steps):
        fetched = sess.run([train_op, loss])
        if (step + 1) % 100 == 0:
            print('step: {0}, loss {1}'.format(step, fetched[1]))

    # Fetch the grid, the exact values and the prediction in one run, then plot
    # the exact curve first and the prediction second.
    x_test, y_exact, y_test = sess.run([x_eval, y_fun, y_eval])
    for curve in (y_exact, y_test):
        plt.plot(x_test, curve)
    plt.show()
step: 99, loss 0.0009860822465270758
step: 199, loss 0.0020077507942914963
step: 299, loss 0.00032659099088050425
step: 399, loss 0.0005868814769200981
step: 499, loss 0.0008676070137880743
step: 599, loss 0.0002522874856367707
step: 699, loss 0.0005854364717379212
step: 799, loss 0.00023204261378850788
step: 899, loss 0.00019143737154081464
step: 999, loss 0.0004034672165289521