4.1 Introduction
This section will cover the details of implementing the fit function above: that is, how to use a dataset of labelled samples to find a function that minimizes the empirical risk.
This requires two ingredients:
- A function class to search over
- A fitting method for minimizing the empirical risk over this class
The two main function classes we will cover are linear models and neural networks. Both of these function classes are parameterized by some parameters θ, and the fitting method will search over these parameters to minimize the empirical risk:

$$\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \sum_{i=1}^{N} \ell\big(f_\theta(x_i), y_i\big),$$

where $\ell$ measures the loss on a single labelled sample $(x_i, y_i)$.
The most common fitting method for parameterized models is gradient descent.
from collections.abc import Callable

from jax import grad
from jaxtyping import Array, Float

Params = Float[Array, " D"]


def gradient_descent(
    loss: Callable[[Params], float],
    θ_init: Params,
    η: float,
    epochs: int,
):
    """
    Run gradient descent to minimize the given loss function
    (expressed in terms of the parameters).
    """
    θ = θ_init
    for _ in range(epochs):
        # Take a step of size η in the direction of steepest descent.
        θ = θ - η * grad(loss)(θ)
    return θ
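As a usage sketch (the quadratic loss below and its minimizer are invented purely for illustration), any scalar-valued function of the parameters can be passed to gradient_descent:

import jax.numpy as np

# A toy quadratic loss with known minimizer θ* = (1, -2), invented for illustration.
def example_loss(θ: Params) -> float:
    return np.sum((θ - np.array([1.0, -2.0])) ** 2)

θ_hat = gradient_descent(example_loss, θ_init=np.zeros(2), η=0.1, epochs=100)
# θ_hat converges towards [1.0, -2.0]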
4.2 Linear regression
In linear regression, we assume that the function is linear in the parameters:

$$f_\theta(x) = \theta^\top x.$$
This function class is extremely simple and only contains linear functions. To expand its expressivity, we can transform the input using some feature function $\phi$, i.e. $x \mapsto \phi(x)$, and then fit a linear model $f_\theta(x) = \theta^\top \phi(x)$ in the transformed space instead.
import jax.numpy as np
from jax import vmap


def fit_linear(X: Float[Array, "N D"], y: Float[Array, " N"], φ=lambda x: x):
    """Fit a linear model to the given dataset using ordinary least squares."""
    X = vmap(φ)(X)  # apply the feature map to every sample
    θ = np.linalg.lstsq(X, y, rcond=None)[0]  # solve the least-squares problem
    return lambda x: np.dot(φ(x), θ)
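As a usage sketch (the cubic toy dataset and the polynomial feature map below are invented for illustration), passing a feature function φ lets the same least-squares fit recover a non-linear function of the input:

import jax.numpy as np

# Toy one-dimensional dataset following y = x³, invented for illustration.
X = np.linspace(-1.0, 1.0, 20)[:, None]  # shape (N, 1)
y = X[:, 0] ** 3                         # shape (N,)

def φ(x):
    # Polynomial feature map (x, x², x³) of the scalar input.
    return np.stack([x[0], x[0] ** 2, x[0] ** 3])

f_hat = fit_linear(X, y, φ)
print(f_hat(np.array([0.5])))  # ≈ 0.125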
4.3 Neural networks
In neural networks, we assume that the function is a composition of linear functions (represented by matrices $W_1, \dots, W_L$) and non-linear activation functions (denoted by $\sigma$):

$$f_\theta(x) = W_L \, \sigma\big( W_{L-1} \cdots \sigma( W_1 x + b_1 ) \cdots + b_{L-1} \big) + b_L,$$

where $W_\ell$ and $b_\ell$ are the parameters of the $\ell$-th layer, and $\sigma$ is the activation function.
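As a minimal sketch of this composition (the layer sizes, the ReLU activation, and the random initialization below are assumptions made for illustration, not part of the definition above):

import jax.numpy as np
from jax import random


def relu(z):
    return np.maximum(z, 0.0)


def mlp(params, x):
    """Forward pass: alternate affine maps W @ h + b with the activation σ (here ReLU)."""
    *hidden, (W_last, b_last) = params
    h = x
    for W, b in hidden:
        h = relu(W @ h + b)
    return W_last @ h + b_last  # no activation on the final (output) layer


# Hypothetical random initialization of a network with layer sizes 2 → 16 → 16 → 1.
sizes = [2, 16, 16, 1]
keys = random.split(random.PRNGKey(0), len(sizes) - 1)
params = [
    (random.normal(key, (n_out, n_in)) / np.sqrt(n_in), np.zeros(n_out))
    for key, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
]

print(mlp(params, np.ones(2)))  # output of shape (1,)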
This function class is much more expressive and typically has many more parameters. This makes it more susceptible to overfitting on smaller datasets, but also allows it to represent more complex functions. In practice, however, neural networks exhibit interesting phenomena during training, and are often able to generalize well even with many parameters.
Another reason for their popularity is the efficient backpropagation algorithm for computing the gradient of the empirical risk with respect to the parameters. Essentially, the hierarchical structure of the neural network, i.e. computing the output of the network as a composition of functions, allows us to use the chain rule to compute the gradient of the output with respect to the parameters of each layer.
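For instance, here is a hedged sketch reusing the mlp and params defined above (the toy batch, squared-error loss, and step size are invented for illustration): a single call to grad backpropagates through the whole composition and returns a gradient for every layer's parameters.

import jax.numpy as np
from jax import grad, vmap

# A toy batch of inputs and targets, invented purely for illustration.
X_batch = np.ones((8, 2))
y_batch = np.zeros((8, 1))


def empirical_risk(params):
    """Mean squared error of the network (mlp from above) over the batch."""
    predictions = vmap(lambda x: mlp(params, x))(X_batch)
    return np.mean((predictions - y_batch) ** 2)


# One call to grad applies the chain rule layer by layer, returning gradients
# with the same nested (W, b) structure as params.
grads = grad(empirical_risk)(params)

# A single gradient descent step on every layer's parameters.
η = 0.01
params = [(W - η * dW, b - η * db) for (W, b), (dW, db) in zip(params, grads)]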
Nielsen (2015) provides a comprehensive introduction to neural networks and backpropagation.
- Nielsen, M. A. (2015). Neural Networks and Deep Learning. Determination Press.