
4 Supervised learning

4.1 Introduction

This section will cover the details of implementing the fit function above: that is, how to use a dataset of labelled samples $(x_1, y_1), \dots, (x_N, y_N)$ to find a function $f$ that minimizes the empirical risk. This requires two ingredients:

  1. A function class $\mathcal{F}$ to search over
  2. A fitting method for minimizing the empirical risk over this class

The two main function classes we will cover are linear models and neural networks. Both of these function classes are parameterized by some parameters θ, and the fitting method will search over these parameters to minimize the empirical risk:

$$\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \sum_{n=1}^{N} \ell(f_\theta(x_n), y_n),$$

where $\ell$ measures the error of a single prediction (e.g. the squared error in regression).

The most common fitting method for parameterized models is gradient descent.

from jaxtyping import Float, Array
from collections.abc import Callable
from jax import grad

Params = Float[Array, " D"]


def gradient_descent(
    loss: Callable[[Params], float],
    θ_init: Params,
    η: float,
    epochs: int,
) -> Params:
    """
    Run gradient descent to minimize the given loss function
    (expressed in terms of the parameters).
    """
    θ = θ_init
    for _ in range(epochs):
        # Take a step of size η in the direction of steepest descent.
        θ = θ - η * grad(loss)(θ)
    return θ
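
As a quick usage sketch, we can minimize the mean-squared-error empirical risk of a linear model with this routine. The synthetic dataset, step size, and number of epochs below are illustrative choices, not prescribed by the text.

import jax
import jax.numpy as np

# Illustrative usage (synthetic data): minimize the mean-squared-error
# empirical risk of a linear model f(x) = θᵀx with gradient_descent.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 3))      # N = 100 samples, D = 3 features
θ_true = np.array([1.0, -2.0, 0.5])
y = X @ θ_true                            # noiseless labels, for simplicity


def empirical_risk(θ: Params) -> float:
    """Mean squared error of the linear model with parameters θ."""
    return np.mean((X @ θ - y) ** 2)


θ_hat = gradient_descent(empirical_risk, θ_init=np.zeros(3), η=0.1, epochs=200)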

4.2 Linear regression

In linear regression, we assume that the function $f$ is linear in the parameters:

$$\mathcal{F} = \{ x \mapsto \theta^\top x \mid \theta \in \mathbb{R}^D \}$$

This function class is extremely simple and only contains linear functions. To expand its expressivity, we can transform the input $x$ using some feature function $\phi$, i.e. $\widetilde{x} = \phi(x)$, and then fit a linear model in the transformed space instead.

import jax.numpy as np
from jax import vmap


def fit_linear(X: Float[Array, "N D"], y: Float[Array, " N"], φ=lambda x: x):
    """Fit a linear model to the given dataset using ordinary least squares."""
    X = vmap(φ)(X)                             # apply the feature map to each sample
    θ = np.linalg.lstsq(X, y, rcond=None)[0]   # least-squares solution for θ
    return lambda x: np.dot(φ(x), θ)
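
For example, a quadratic feature map turns this into polynomial regression. The feature map and toy dataset below are illustrative, not part of the text.

def poly_features(x: Float[Array, " 1"]) -> Float[Array, " 3"]:
    """Map a scalar input to the features (1, x, x²)."""
    return np.concatenate([np.ones(1), x, x**2])


# Toy 1-D dataset generated from a known quadratic (illustrative only).
X = np.linspace(-1.0, 1.0, 50).reshape(-1, 1)
y = 2.0 - 3.0 * X[:, 0] + 0.5 * X[:, 0] ** 2

f = fit_linear(X, y, poly_features)
print(f(np.array([0.3])))   # ≈ 2 - 0.9 + 0.045 = 1.145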

4.3 Neural networks

In neural networks, we assume that the function $f$ is a composition of linear functions (represented by matrices $W_i$) and non-linear activation functions (denoted by σ):

$$\mathcal{F} = \{ x \mapsto \sigma(W_L \sigma(W_{L-1} \dots \sigma(W_1 x + b_1) \dots + b_{L-1}) + b_L) \}$$

where $W_i \in \mathbb{R}^{D_{i+1} \times D_i}$ and $b_i \in \mathbb{R}^{D_{i+1}}$ are the parameters of the $i$-th layer, and σ is the activation function.
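
A minimal sketch of this function class in code might look as follows; the layer sizes, initialization scale, and choice of ReLU activation are illustrative assumptions, not prescribed by the text.

import jax
import jax.numpy as np


def init_mlp(key, sizes: list[int]):
    """Create (W_i, b_i) for each layer, with W_i of shape (D_{i+1}, D_i)."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (d_out, d_in)) / np.sqrt(d_in), np.zeros(d_out))
        for k, d_in, d_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x, σ=jax.nn.relu):
    """Evaluate σ(W_L σ(… σ(W_1 x + b_1) …) + b_L), as in the formula above."""
    for W, b in params:
        x = σ(W @ x + b)
    return x


params = init_mlp(jax.random.PRNGKey(0), [3, 16, 16, 1])
print(mlp(params, np.ones(3)).shape)   # (1,)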

This function class is much more expressive, and its members have many more parameters. This makes neural networks more susceptible to overfitting on smaller datasets, but it also allows them to represent much more complex functions. In practice, however, neural networks exhibit interesting phenomena during training and often generalize well even with very many parameters.

Another reason for their popularity is the efficient backpropagation algorithm for computing the gradient of the empirical risk with respect to the parameters. Because the network computes its output as a composition of functions, the chain rule lets us compute the gradient of the loss with respect to the parameters of each layer, reusing the intermediate values from the forward pass.
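
In JAX this is handled by reverse-mode automatic differentiation, of which backpropagation is the special case for neural networks. Below is a brief, illustrative sketch building on the mlp function above; the toy batch is made up.

from jax import grad, vmap

# Illustrative: jax.grad applies the chain rule through the layered
# computation, which is exactly what backpropagation does.
def nn_empirical_risk(params, X, y):
    preds = vmap(lambda x: mlp(params, x))(X)   # shape (N, 1)
    return np.mean((preds[:, 0] - y) ** 2)


X_batch = np.ones((8, 3))   # toy batch: N = 8 samples, D = 3 features
y_batch = np.zeros(8)
grads = grad(nn_empirical_risk)(params, X_batch, y_batch)
# `grads` has the same (W_i, b_i) structure as `params`, one gradient per layer.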

Nielsen (2015) provides a comprehensive introduction to neural networks and backpropagation.

References
  1. Nielsen, M. A. (2015). Neural Networks and Deep Learning. Determination Press.