First lab assignment in Deep learning

Multi-layer network gradients, introduction to PyTorch, fully connected deep models, case study: MNIST

The topic of this exercise is deep models for classification. We'll show that these models can be viewed either as multi-layer feed-forward neural networks or as augmented logistic regression models which were introduced in exercise zero. Both views lead to the same implementation which iteratively optimizes the likelihood of model parameters. To make development easier and to speed up the experiments, we'll study automatic differentiation opportunities offered by several numeric optimization frameworks. Special attention will be given to PyTorch, which is one of the most frequently used tools in this category.

The goal of this exercise is to develop seven modules: data, fcann2, pt_linreg, pt_logreg, pt_deep, ksvm_wrap, and mnist_shootout. The module data will be an upgraded version of the corresponding module from lab 0. The module fcann2 will contain the implementation of a two-layer fully-connected model implemented in terms of NumPy primitives. This module should be very similar to the module logreg from lab 0. Modules pt_linreg, pt_logreg, and pt_deep will contain PyTorch implementations of three machine learning algorithms of increasing complexity. The module ksvm_wrap will wrap a kernel-based support vector classifier based on the module sklearn.svm from the scikit-learn library, and enable comparison to the deep classifiers. Finally, the module mnist_shootout will evaluate the generalization performance of several deep learning techniques on the MNIST dataset.

0a. Introductory notes on deep models

Deep learning models are based on abstract data representations which we obtain with a sequence of trained nonlinear transformations. In this and the following laboratory exercises we consider deep models which are discriminative and feed-forward. Discriminative models predict a conditional probability \(P(Y|\mathbf{x})\) of a dependent variable \(Y\) given data \(\mathbf{x}\). Discriminative models are typically used when we have labeled data suitable for supervised learning. Feed-forward models have a unidirectional information flow, which means that intermediate processing results cannot be connected back towards the input. The basic deep discriminative model can be viewed either as an augmented logistic regression model or as a multi-layer feed-forward artificial neural network.

Artificial neural networks

Artificial neural networks are machine learning models which may be expressed as a directed graph of uniform scalar processing units called artificial neurons. One of the important goals of artificial neural networks is to define a computational model of biological processes. In other words, artificial neurons strive to model the mechanism of learning in the brains of living organisms. Artificial neural networks are related to deep learning. The main difference is that deep learning does not have the ambition to model biological processes. Instead, deep learning studies compositional models of practical significance which may not have any biological interpretation.

Artificial neurons typically conduct an affine reduction of an input vector, which can be concisely expressed as \(f(\mathbf{w}^\top\mathbf{x}+b)\). Here, the vector \(\mathbf{x}\) contains the input variables, the vector \(\mathbf{w}\) and the scalar \(b\) represent free parameters which are optimized during training, while the nonlinear function \(f\) represents the activation of the artificial neuron. The goal of the activation function \(f\) is to introduce nonlinearity to the model. If we pick a softmax function, the artificial neuron will conduct multi-class logistic regression. If we pick a sigmoid function \(\sigma(s)=e^s/(1+e^s)\), the artificial neuron will perform binary logistic regression. For better learning of deep models, the sigmoid is being replaced by the rectified linear unit: \(\mathrm{ReLU}(s) = \max(0, s)\).
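For illustration, here is a minimal sketch of a single artificial neuron; the input and weight values are hypothetical:

import numpy as np

def sigmoid(s):
    return np.exp(s) / (1 + np.exp(s))

x = np.array([1.0, -2.0, 0.5])   # input vector
w = np.array([0.3, 0.1, -0.4])   # weights (free parameters)
b = 0.2                          # bias (free parameter)

s = w @ x + b                    # affine reduction: w^T x + b
print(sigmoid(s))                # sigmoid activation: binary logistic regression
print(np.maximum(s, 0))          # ReLU activation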

Multi-layer feed-forward models

A neural network with one input layer, softmax on the output, and a loss which maximizes the likelihood of its parameters is equivalent to (possibly multi-class) logistic regression. However, in this course we study deep models which are constructed by adding one or more nonlinear transformations between the input and the output. In the first half of this course, we focus on feed-forward models which do not include recurrent connections. Such models can be expressed with acyclic computational graphs where nodes correspond to arbitrary operations, while edges model the connectivity. As in logistic regression, deep feed-forward models are trained by optimizing the likelihood of model parameters given the data and the desired outputs. Contrary to logistic regression, the loss function of deep models is not convex, which means there is no guarantee of finding the global optimum.

In this exercise, we focus on fully-connected feed-forward models. This class of models can be viewed as augmented logistic regression where we introduce several latent layers before the softmax classification. The resulting models are able to achieve nonlinear decision boundaries at the cost of non-convex optimization. Fully-connected feed-forward models can also be viewed as multi-layer artificial neural networks. In such networks, neurons can be organized in layers \(S_k\) such that neurons of layer \(k\) operate on all neurons of layer \(k-1\).

We can express each layer of a fully-connected feed-forward model as a composition of an affine vector transform and a nonlinear activation function. For simplicity, we define that applying a scalar transfer function to a vector variable results in elementwise application of the transfer function to the input. Then, we arrive at the following concise notation of the k-th layer of a deep fully-connected feed-forward model with ReLU activation:

\( \mathbf{s_k} = \mathbf{W_k}\cdot\mathbf{h_{k-1}} + \mathbf{b_k} \\ \mathbf{h_k} = \mathrm{ReLU}(\mathbf{s_k}) \)
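A vectorized sketch of one such layer for a batch of N inputs; the dimensions are assumptions chosen for illustration:

import numpy as np

N, d_in, d_out = 5, 3, 4
H_prev = np.random.randn(N, d_in)    # batch of inputs h_{k-1}, one row per datum
W = np.random.randn(d_out, d_in)     # weight matrix W_k
b = np.zeros(d_out)                  # bias b_k

S = H_prev @ W.T + b                 # affine transform, shape (N, d_out)
H = np.maximum(S, 0)                 # elementwise ReLU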

The following illustration displays two views onto the same two-layer fully-connected model. The classic view is shown on the left, where circles designate affine scalar neurons and ReLU activations. The right figure shows a vectorized computational graph which we use in this course.

Gradients in a two-layer fully-connected model

Let's consider how to determine the gradients of the negative log-likelihood loss function in the aforementioned example of a two-layer fully-connected model. We'll express the model in vector equations as follows:

\( \mathbf{s_1} = \mathbf{W_1} \cdot \mathbf{x} + \mathbf{b_1} \\ \mathbf{h_1} = \mathrm{ReLU}(\mathbf{s_1}) \\ \mathbf{s_2} = \mathbf{W_2} \cdot \mathbf{h_1} + \mathbf{b_2} \\ P(Y|\mathbf{x}) = \mathrm{softmax}(\mathbf{s_2}) . \)

Our loss function will be the sum of negative log-likelihoods of the model across all data:

\( L(\mathbf{W_1},\mathbf{b_1}, \mathbf{W_2},\mathbf{b_2}| \mathbf{X}, \mathbf{y}) = \sum_i -\log P(Y=y_i|\mathbf{x}_i) \)

We see that the loss function corresponds to a composition of several simpler functions. The loss \(L\) depends on the probabilities \(P\), which depend on the linear classification scores of the second layer \(\mathbf{s_2}\), which in turn depend on the hidden layer \(\mathbf{h_1}\) and the parameters \(\mathbf{W_2}\) and \(\mathbf{b_2}\). The hidden layer \(\mathbf{h_1}\) depends on the linear score \(\mathbf{s_1}\), which finally depends on the parameters \(\mathbf{W_1}\) and \(\mathbf{b_1}\) and the data \(\mathbf{x}\). Therefore, we determine the gradients of the loss with respect to the parameters using the chain rule.

Partial derivatives of the loss with respect to the j-th rows of \(\mathbf{W_2}\) and \(\mathbf{b_2}\) will be similar to their counterparts in multi-class logistic regression (cf. lab exercise 0). We'll exploit the algebraic structure of the problem, which can be concisely expressed as: \( \partial {s_2}_{ij}/ \partial \mathbf{W_2}_{k:} = \partial {s_2}_{ij}/ \partial {b_2}_{k} = 0, \; \forall k \neq j \; . \)

In order to achieve a more compact notation, we denote the k-th row of the matrix \(\mathbf{W_2}\) as \(\mathbf{W_2}_{k:}\). Furthermore, we express the result in terms of the matrix of a-posteriori probabilities \(\mathbf{P}\) and the matrix of one-hot coded labels \(\mathbf{Y'}\) that we have introduced in lab exercise 0. Finally, we arrive at the following expressions:

\( \frac{∂L_i}{∂\mathbf{W_2}_{j:}} = \frac{∂L_i}{∂{s_2}_{ij}} \cdot \frac{∂{s_2}_{ij}}{∂\mathbf{W_2}_{j:}} = ({P}_{ij} - {Y'}_{ij}) \cdot \mathbf{h_1}_i^\top \; , \\ \frac{∂L_i}{∂{b_2}_{j}} = \frac{∂L_i}{∂{s_2}_{ij}} \cdot \frac{∂{s_2}_{ij}}{∂{b_2}_{j}} = ({P}_{ij} - {Y'}_{ij}) \; . \)

In order to get the gradients w.r.t. \(\mathbf{W_1}\) and \(\mathbf{b_1}\), we need to propagate over all components of the second layer. However, this propagation is easily expressed because the Jacobian of the linear layer matches the weight matrix, while the Jacobian of the ReLU is a diagonal matrix whose diagonal entries reflect the sign of the corresponding component of the first layer. When we arrive at the linear score of the first layer, we can reuse much of the work performed in the second layer. In fact, the dependency pattern of the classification score \(\mathbf{s_2}\) w.r.t. the parameters of the second layer is the same as for the linear score \(\mathbf{s_1}\) w.r.t. the first layer parameters. Thus, the analytical expressions of the partial derivatives \(\partial\mathbf{s}_1/\partial\mathbf{W_1}\) are quite similar to the corresponding expressions from the second layer:

\( \frac{∂L_i}{∂\mathbf{s_1}_{i}} = \frac{∂L_i}{∂\mathbf{s_2}_i} \cdot \frac{∂\mathbf{s_2}_i}{∂\mathbf{h_1}_{i}} \cdot \frac{∂\mathbf{h_1}_{i}}{∂\mathbf{s_1}_{i}} = (\mathbf{P}_{i:} - \mathbf{Y'}_{i:}) \cdot \mathbf{W_2} \cdot \mathrm{diag}([\![\mathbf{s_1}_{i:}>0]\!]) \;, \\ \frac{∂L_i}{∂\mathbf{W_1}_{j:}} = \frac{∂L_i}{∂{s_1}_{ij}} \frac{∂{s_1}_{ij}}{∂\mathbf{W_1}_{j:}} = \frac{∂L_i}{∂{s_1}_{ij}} \mathbf{x}_i^\top \;, \\ \frac{∂L_i}{∂{b_1}_{j}} = \frac{∂L_i}{∂{s_1}_{ij}} \frac{∂{s_1}_{ij}}{∂{b_1}_{j}} = \frac{∂L_i}{∂{s_1}_{ij}} \)

In the following text, we shall use the term gradient both for the partial derivative of the loss with respect to the parameters as well as for individual parts of that vector. The exact meaning will be conveyed by the context. The same convention is used in the scientific literature. Thus, the expression "four gradients" refers to the left-hand sides of the above equations for the four parameter tensors \(\mathbf{W_1}, \mathbf{b_1}, \mathbf{W_2}, \mathbf{b_2}\).

It should be noted here that our ambition is not to efficiently calculate some gradients for some data points. On the contrary, our goal is to efficiently calculate all gradients for many data points by relying on optimized matrix algebra libraries. We follow this strategy since most of the speed-improvement potential lies in cache optimizations which cannot be exploited on small data batches. Hence, we will calculate the gradients of each layer for a large batch of data and all parameter rows (as in logistic regression) with a single matrix multiplication.

However, unlike in logistic regression, in deep models we must decide in which order to calculate the particular gradients (e.g., should we first calculate \(\frac{∂L_i}{∂\mathbf{b_1}}\) or \(\frac{∂L_i}{∂\mathbf{b_2}}\)?). The answer to this question is provided by the backpropagation algorithm.

Backpropagation

The vector equations of our model with two fully-connected layers reflect a common theme in deep models: we can see that the partial derivative \(\frac{∂L_i}{∂\mathbf{s_2}_i}\) occurs in the gradients of the loss w.r.t. all four parameter tensors of the model (\(\mathbf{W_1}, \mathbf{b_1}, \mathbf{W_2}, \mathbf{b_2}\)). This can be used to recover the gradients with minimal computational effort. In fact, we may notice that the gradients of the loss function w.r.t. the nodes of the computational graph do not have to be calculated more than once if they are calculated backwards, from the output towards the input of the model. This simple but very efficient approach is formalized by the backprop algorithm.

The backprop algorithm is illustrated in the figure below. The black arrows display the evaluation of the model and the loss function for a given data point. This evaluation is called the forward pass. The red arrows show the gradient calculation as suggested by the backprop algorithm. This evaluation is called the backward pass.


It appears that all components of the solution to our problem are now known. We know how to calculate gradients w.r.t. particular parameters, as well as the order in which to do so. However, we wish to emphasize two non-trivial details. The first one is the loop through data points. If we want to enjoy the advantages of optimized libraries and avoid iterating in Python, then each gradient should be calculated for a large batch of data at once. If our batch contains N data points, we'll first calculate the N rows of the matrix \( \mathbf{G}_\mathbf{s_2} = [ (\frac{∂L_i}{∂\mathbf{s_2}_{i}})_{i=1}^N ] \), and then the N rows of the matrix \( \mathbf{G}_\mathbf{h_1} = [ (\frac{∂L_i}{∂\mathbf{h_1}_{i}})_{i=1}^N ] \), etc. Such an approach may appear strange, since it heavily increases the memory requirements. However, this is the price for being able to express our algorithms in terms of optimized matrix algorithms. Failing to do so would increase the training time by several orders of magnitude.

Another detail is calculating the gradients of the weight matrices. Instead of separately calculating the gradients over rows (as suggested by the equations above), we advise you to use the approach from lab exercise 0 (section 0d). There we show that the entire matrix of gradients (which in gradient descent is added to the weight matrix) can be expressed by a single matrix multiplication. The weight gradients in the k-th layer \([\partial L/\partial\mathbf{W_k}]\) can be obtained for the whole batch at once by multiplying the transposed gradients of the linear score \(\mathbf{G}_\mathbf{s_k}\) with the input matrix \(\mathbf{H}_{k-1}\). The vector of gradients for the k-th layer bias can be calculated in a similar manner. However, instead of matrix multiplication we need to perform column summation of the matrix \(\mathbf{G}_\mathbf{s_k}\), which can be obtained by np.sum.

To summarize the discussion for this particular case of a two-layer network, we would calculate the partial derivatives of the loss function across all data in the batch in the following order (a NumPy sketch follows the list):

  1. gradients of the loss w.r.t. the linear score of the second layer: \(\mathbf{G}_\mathbf{s_2} = \mathbf{P} - \mathbf{Y'}\)
  2. gradients of the loss w.r.t. the second layer parameters: \(\partial L/\partial\mathbf{W_2} = \mathbf{G}_\mathbf{s_2}^\top \cdot \mathbf{H_1}\), and \(\partial L/\partial\mathbf{b_2}\) as column sums of \(\mathbf{G}_\mathbf{s_2}\)
  3. gradients of the loss w.r.t. the nonlinear output of the first layer across all data: \(\mathbf{G}_\mathbf{h_1} = \mathbf{G}_\mathbf{s_2} \cdot \mathbf{W_2}\)
  4. gradients of the loss w.r.t. the linear score of the first layer across all data: \(\mathbf{G}_\mathbf{s_1} = \mathbf{G}_\mathbf{h_1} \odot [\![\mathbf{S_1}>0]\!]\)
  5. gradients of the loss w.r.t. the first layer parameters: \(\partial L/\partial\mathbf{W_1} = \mathbf{G}_\mathbf{s_1}^\top \cdot \mathbf{X}\), and \(\partial L/\partial\mathbf{b_1}\) as column sums of \(\mathbf{G}_\mathbf{s_1}\)
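Here is a minimal NumPy sketch of this computation. It assumes the forward pass has already produced the batch matrices X (inputs), S1 and H1 (first-layer scores and activations), P (softmax outputs), and Yoh (one-hot labels); the function name and argument layout are illustrative choices:

import numpy as np

def backward_pass(X, S1, H1, P, Yoh, W2):
    Gs2 = P - Yoh                 # 1. gradients w.r.t. second-layer scores, shape (N, C)
    grad_W2 = Gs2.T @ H1          # 2. second-layer weight gradients
    grad_b2 = Gs2.sum(axis=0)     #    second-layer bias gradients (column sums)
    Gh1 = Gs2 @ W2                # 3. gradients w.r.t. the hidden layer
    Gs1 = Gh1 * (S1 > 0)          # 4. gradients w.r.t. first-layer scores (ReLU Jacobian)
    grad_W1 = Gs1.T @ X           # 5. first-layer weight gradients
    grad_b1 = Gs1.sum(axis=0)     #    first-layer bias gradients
    return grad_W1, grad_b1, grad_W2, grad_b2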
Data normalization and parameter initialization

In models trained with gradient descent, parameter initialization is highly important. For latent layers activated by the sigmoid function, activations must be centered around zero to allow the sigmoid to be effective. For example, if all inputs were positive and all weights were positive, all sigmoid activations would be permissive, and the layer's effect would be entirely linear. This is not desirable because a composition of linear transformations is, again, a linear transformation, and we know that linear transformations have much lower capacity than a deep composition of nonlinear transformations. The lesson from this discussion is that deep model learning will progress well if the input data are normalized to zero mean and unit variance, and if the parameters are randomly initialized with values centered around zero.
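A short sketch of both techniques; the scale factor of the random initialization is an assumption (schemes such as Xavier or He initialization differ precisely in this choice):

import numpy as np

# normalize the data to zero mean and unit variance, per feature
X = np.random.rand(100, 2) * 10               # hypothetical raw data
X = (X - X.mean(axis=0)) / X.std(axis=0)

# initialize weights with small zero-centered random values, biases with zeros
D, H = 2, 5
W1 = np.random.randn(H, D) * np.sqrt(1 / D)   # assumed scale: 1/sqrt(fan-in)
b1 = np.zeros(H)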

0b. Introductory notes on PyTorch

PyTorch is an open source library for designing machine learning methods with a special emphasis on the following key functionalities:

  1. tensor computation with strong GPU acceleration,
  2. automatic differentiation of computations expressed as dynamic computational graphs.

Although there are other similar tools (TensorFlow, MXNet, etc.), PyTorch is currently the most popular among researchers. It is also suitable for beginners due to its clean architecture and comprehensive documentation.

PyTorch supports various operating systems, but Linux is generally the most well-supported and up-to-date platform for machine learning. For the latest information about PyTorch, you can refer to the official website.

Let's illustrate the design of a program under PyTorch on the following simple example:

import torch
# we define the operation
def f(x, a, b):
    return a*x + b

# we define the variables 
# and build a dynamic computational graph 
# with a forward pass
a = torch.tensor(5., requires_grad=True)
b = torch.tensor(8., requires_grad=True)
x = torch.tensor(2.)
y = f(x, a, b)
s = a ** 2

# a backward pass that calculates the gradients
# for all tensors that set requires_grad=True
y.backward()
s.backward()               # gradients are accumulated
assert x.grad is None      # PyTorch does not calculate gradients with respect to x
assert a.grad == x + 2 * a # dy/da + ds/da
assert b.grad == 1         # dy/db + ds/db

# print out results
print(f"y={y}, g_a={a.grad}, g_b={b.grad}")

The first part of the example defines a regular Python function. The return value of this function will seamlessly fit into PyTorch's computational graph.

The second part of the example creates objects a, b, and x of type torch.Tensor, which correspond to nodes in the computational graph. Tensors a and b have the attribute requires_grad=True, which means that PyTorch will compute gradients for them during automatic backward pass. Calling operations *, +, and ** creates new objects of type torch.Tensor, which are also nodes in the computational graph. We will refer to objects of type torch.Tensor as tensors.

During the computation of the values of graph nodes, PyTorch remembers all intermediate results that are necessary for gradient computation. The details of this process are determined by the automatic differentiation algorithm (autograd for short).

The third part of the example computes the gradients of the nodes y and s with the backward method. Autograd performs backward propagation all the way back to a and b, thus computing their gradients. Multiple calls to the backward method accumulate gradients in the grad attribute of each tensor declared with requires_grad=True. It is worth noting that the sequence of calls y.backward(); s.backward() achieves the same effect as (y + s).backward().

The grad attribute is also an instance of the torch.Tensor class, but it is typically separate from the computational graph that contains its parent tensor. You can compute higher-order derivatives by calling the backward method with the argument create_graph=True, which requests that the derivatives of tensors are also included in the computational graph.

If you want to recalculate the gradient, for example for a different x, you need to reset the existing gradient to avoid accumulation. You can achieve this by deleting the grad attribute or by setting it to None, e.g. a.grad = None.
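Continuing the first example, a short sketch of recomputing the gradient after such a reset:

a.grad = None              # discard the accumulated gradient
x = torch.tensor(3.)
y = f(x, a, b)
y.backward()
assert a.grad == 3         # dy/da = x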

If you don't need to compute gradients for a specific computation, it's a good practice to express it within the torch.no_grad() context manager, which disables autograd (making PyTorch behave like NumPy). Here is an example of a procedure that calculates a confusion matrix based on vectors of true labels y_true and predictions y_pred, and it returns a confusion matrix of dimensions class_count x class_count.

import torch
import torch.nn.functional as F

def multiclass_confusion_matrix(y_true, y_pred, class_count):
  with torch.no_grad():
    y_pred = F.one_hot(y_pred, class_count)
    cm = torch.zeros([class_count] * 2, dtype=torch.int64, device=y_true.device)
    for c in range(class_count):
      cm[c, :] = y_pred[y_true == c, :].sum(0)
    return cm  
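A usage example with hand-picked labels:

y_true = torch.tensor([0, 1, 2, 1])
y_pred = torch.tensor([0, 2, 2, 1])
print(multiclass_confusion_matrix(y_true, y_pred, 3))
# tensor([[1, 0, 0],
#         [0, 1, 1],
#         [0, 0, 1]])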

This example shows that PyTorch enables the placement of tensors (and computations) on different platforms/devices using the optional device argument, which is accepted by all PyTorch functions that create new tensors. Note that you can also specify the data type using the optional dtype argument. Instead of the demonstrated torch.zeros function call, you could specify an explicit conversion, like torch.zeros([class_count] * 2).to(dtype=torch.int64, device=y_true.device), which would yield the same result but would also lead to the unnecessary creation of an intermediate tensor. Typical values for the device argument are torch.device('cpu') (the main processor and memory) or torch.device('cuda:0') (the first GPU under the CUDA platform). You can also specify the device as a string: device='cpu' or device='cuda:0'.
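A small sketch of device placement, assuming a CUDA-capable GPU may or may not be present:

t = torch.zeros(2, 3, dtype=torch.float32, device='cpu')
if torch.cuda.is_available():
    t = t.to('cuda:0')     # copy the tensor to the first GPU
print(t.device)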

You may find more information in the official PyTorch documentation.

0c. PyTorch in machine learning

A program that uses PyTorch typically consists of the following components:

  1. a model represented by an object derived from torch.nn.Module, which typically contains other modules with parameters,
  2. procedures for loading and processing input data,
  3. the learning loop.

Procedures for loading and processing data typically include:

  1. A dataset - usually an object derived from torch.utils.data.Dataset.
  2. Data preprocessing components on the CPU.
  3. Data sampling components that determine the order of loading (such as torch.utils.data.Sampler).
  4. A component for parallel data loading and preprocessing (torch.utils.data.DataLoader).

The elements of a typical learning algorithm include (a minimal sketch follows the list):

  1. Procedures for initializing parameters.
  2. A loss function.
  3. An optimization algorithm - typically an object derived from torch.optim.Optimizer.
  4. A training loop that fetches data, computes loss gradients, and applies optimization steps to update the parameters.
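The following is a minimal sketch of these elements; the model, loss, learning rate, and random data are assumptions chosen for illustration:

import torch

model = torch.nn.Linear(3, 1)                  # 1. parameters initialized by the module
loss_fn = torch.nn.MSELoss()                   # 2. a loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # 3. an optimization algorithm

X, Y = torch.randn(32, 3), torch.randn(32, 1)  # hypothetical data
for epoch in range(100):                       # 4. the training loop
    loss = loss_fn(model(X), Y)
    optimizer.zero_grad()                      # reset accumulated gradients
    loss.backward()                            # compute loss gradients
    optimizer.step()                           # apply the optimization step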

The following code shows an example of a model which performs an affine transformation:

import torch

class Affine(torch.nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.out_features = out_features
    self.linear = torch.nn.Linear(in_features, out_features, bias=False)
    self.bias = torch.nn.Parameter(torch.zeros(out_features))

  def forward(self, input):
    return self.linear(input) + self.bias

The example includes a submodel of type torch.nn.Linear (which inherently supports a bias, although it is not used in this example) and a parameter of type torch.nn.Parameter. The torch.nn.Parameter type is derived from torch.Tensor and is primarily used to distinguish parameters from other tensors (the requires_grad attribute is set to True by default). The torch.nn.Module class defines methods that return iterators over modules (modules), submodules (children), parameters (parameters), etc. Methods with the prefix named_ return pairs of names (paths) and objects, as illustrated in the following example.

>>> affine = Affine(3, 4)
>>> print(list(affine.named_parameters()))
[('bias',
  Parameter containing:
  tensor([0., 0., 0., 0.], requires_grad=True)),
 ('linear.weight',
  Parameter containing:
  tensor([[-0.2684,  0.2126, -0.4430],
          [ 0.3446, -0.2018, -0.4346],
          [-0.4756, -0.3453,  0.1401],
          [ 0.3257,  0.0911, -0.1267]], requires_grad=True))]

We typically design modules to operate on mini-batches of data. For example, a call like affine(torch.randn(5, 3)) results in a tensor of dimensions (5, 4), where torch.randn(5, 3) creates a matrix of elements randomly drawn from a standard normal distribution.

More information about modules can be found in the official documentation. Various procedures for parameter initialization can be found in the torch.nn.init package.
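For instance, here is a hypothetical re-initialization of the Affine module from above using two procedures from torch.nn.init:

import torch.nn as nn

affine = Affine(3, 4)
nn.init.xavier_uniform_(affine.linear.weight)  # rescaled uniform initialization
nn.init.zeros_(affine.bias)                    # biases reset to zero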

The next example demonstrates the basics of data loading:

import numpy as np
import torch
from torch.utils.data import DataLoader

dataset = [(torch.randn(4, 4), torch.randint(5, size=())) for _ in range(25)]
dataset = [(x.numpy(), y.numpy()) for x, y in dataset]
loader = DataLoader(dataset, batch_size=8, shuffle=False,
                    num_workers=0, collate_fn=None, drop_last=False)
for x, y in loader:
  print(x.shape, y.shape)

The example first generates a dataset of 25 random pairs of 4x4 matrices and scalar labels. For illustrative purposes, the data is converted to the numpy.ndarray type without copying. The dataset is then passed to the constructor of the DataLoader class, whose instance facilitates iteration over mini-batches. With batch_size=8 and drop_last=False, the loop prints three batches of shape (8, 4, 4) and (8,), followed by a final partial batch of shape (1, 4, 4) and (1,).

Here are some important arguments that the constructor accepts: