
6 Policy Gradient Methods

6.1 Introduction

The core task of RL is finding the optimal policy in a given environment. This is essentially an optimization problem: out of some space of policies, we want to find the one that achieves the maximum total reward (in expectation).

It’s typically intractable to compute the optimal policy exactly in some finite number of steps. Instead, policy optimization algorithms start from some randomly initialized policy and then improve it step by step. We’ve already seen some examples of these, namely Section 1.5.3.2 for finite MDPs and Section 2.6.4 for continuous control.

In particular, we often use policies that can be described by some finite set of parameters. We will see some examples in Section 6.3.1. For such parameterized policies, we can approximate the policy gradient: the gradient of the expected total reward with respect to the parameters. This tells us the direction the parameters should be updated to achieve a higher expected total reward. Policy gradient methods are responsible for groundbreaking applications including AlphaGo, OpenAI Five, and large language models, many of which use policies parameterized as deep neural networks.

  1. We begin the chapter with a short review of gradient ascent, a general optimization method.
  2. We’ll then see how to estimate the policy gradient, enabling us to apply (stochastic) gradient ascent in the RL setting.
  3. Then we’ll explore some proximal optimization techniques that ensure the steps taken are “not too large”. This is helpful to stabilize training and widely used in practice.
%load_ext autoreload
%autoreload 2
from utils import plt, Array, Float, Callable, jax, jnp, latex, gym

6.2 Gradient Ascent

Gradient ascent is a general optimization algorithm for any differentiable function. A suitable analogy for this algorithm is hiking up a mountain, where you keep taking steps in the steepest direction upwards. Here, your vertical position $y$ is the function being optimized, and your horizontal position $(x, z)$ is the input to the function. The slope of the mountain at your current position is given by the gradient, written $\nabla y(x, z) \in \mathbb{R}^2$.

For differentiable functions, this can be thought of as the vector of partial derivatives,

$$\nabla y(x, z) = \begin{pmatrix} \frac{\partial y}{\partial x} \\ \frac{\partial y}{\partial z} \end{pmatrix}.$$

To calculate the slope (aka “directional derivative”) of the mountain in a given direction $(\Delta x, \Delta z)$, you take the dot product of the difference vector with the gradient. This means that the direction with the highest slope is exactly the gradient itself, so we can describe the gradient ascent algorithm as follows:

$$\begin{pmatrix} x^{k+1} \\ z^{k+1} \end{pmatrix} = \begin{pmatrix} x^{k} \\ z^{k} \end{pmatrix} + \eta \nabla y(x^{k}, z^{k})$$

where $k$ denotes the iteration of the algorithm and $\eta > 0$ is a “step size” hyperparameter that controls the size of the steps we take. (Note that we could also vary the step size across iterations, that is, $\eta^0, \dots, \eta^K$.)

The case of a two-dimensional input is easy to visualize. But this idea can be straightforwardly extended to higher-dimensional inputs.

From now on, we’ll use $J$ to denote the function we’re trying to maximize, and $\theta$ to denote the parameters being optimized over. (In the above example, $\theta = \begin{pmatrix} x & z \end{pmatrix}^\top$.)

Notice that our parameters will stop changing once $\nabla J(\theta) = 0$. Once we reach this stationary point, our current parameters are “locally optimal” in some sense: to first order, moving in any direction cannot increase the function. If $J$ is convex, then the only point where this happens is at the global optimum. Otherwise, if $J$ is nonconvex, the best we can hope for is a local optimum.
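As a concrete illustration, here is a minimal sketch of gradient ascent on a toy two-dimensional “mountain”, using jax.grad to compute the gradient. The function y, its peak location, and the hyperparameters below are made up for illustration.

import jax.numpy as jnp
from jax import grad

def y(params):
    # a concave "mountain" whose peak is at (x, z) = (1, -2)
    x, z = params
    return -((x - 1.0) ** 2) - (z + 2.0) ** 2

def gradient_ascent(f, params_init, eta=0.1, n_steps=200):
    """Repeatedly step in the direction of steepest ascent."""
    params = params_init
    grad_f = grad(f)
    for _ in range(n_steps):
        params = params + eta * grad_f(params)
    return params

peak = gradient_ascent(y, jnp.array([0.0, 0.0]))
# peak is approximately [1.0, -2.0], where the gradient vanishes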

6.2.1 Computing derivatives

How does a computer compute the gradient of a function?

One way is symbolic differentiation, which is similar to the way you might compute it by hand: the computer applies a list of rules to transform the symbols involved. Python’s sympy package supports symbolic differentiation. However, functions implemented in code may not always have a straightforward symbolic representation.

Another way is numerical differentiation, which is based on the limit definition of a (directional) derivative:

$$\nabla_{\boldsymbol{u}} J(\boldsymbol{x}) = \lim_{\varepsilon \to 0} \frac{J(\boldsymbol{x} + \varepsilon \boldsymbol{u}) - J(\boldsymbol{x})}{\varepsilon}$$

Then, we can substitute a small value of $\varepsilon$ on the r.h.s. to approximate the directional derivative. How small, though? If we need an accurate estimate, we may need such a small value of $\varepsilon$ that typical computers will run into rounding errors. Also, to compute the full gradient, we would need to compute the r.h.s. once for each input dimension. This is an issue if computing $J$ is expensive.

Automatic differentiation achieves the best of both worlds. Like symbolic differentiation, we manually implement the derivative rules for a few basic operations. However, instead of executing these on the symbols, we execute them on the values when the function gets called, like in numerical differentiation. This allows us to differentiate through programming constructs such as branches or loops, and doesn’t involve any arbitrarily small values. Baydin et al. (2018) provides an accessible survey of automatic differentiation.
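To make the comparison concrete, here is a small sketch (using the jax and jnp imports above) that contrasts a central finite-difference approximation with automatic differentiation; the test function J below is made up for illustration.

def J(x):
    return jnp.sin(x[0]) * x[1] ** 2

def numerical_gradient(J, x, eps=1e-5):
    """Approximate each partial derivative with a central finite difference."""
    dim = x.shape[0]
    return jnp.array([
        (J(x + eps * jnp.eye(dim)[i]) - J(x - eps * jnp.eye(dim)[i])) / (2 * eps)
        for i in range(dim)
    ])

x = jnp.array([0.5, 2.0])
numerical_gradient(J, x)  # two evaluations of J per input dimension, accurate up to rounding error
jax.grad(J)(x)            # one reverse-mode pass, exact up to floating-point precision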

6.2.2 Stochastic gradient ascent

In real applications, computing the gradient of the target function is not so simple. As an example from supervised learning, $J(\theta)$ might be the sum of squared prediction errors across an entire training dataset. However, if our dataset is very large, it might not fit into our computer’s memory! Typically in these cases, we compute some estimate of the gradient at each step, and walk in that direction instead. This is called stochastic gradient ascent. In the SL example above, we might randomly choose a minibatch of samples and use them to estimate the true prediction error. (This approach is known as minibatch SGD.)

def sgd(
    theta_init: Float[Array, " D"],
    estimate_gradient: Callable[[Float[Array, " D"]], Float[Array, " D"]],
    eta: float,
    n_steps: int,
):
    """Perform `n_steps` steps of SGD.

    `estimate_gradient` eats the current parameters and returns an estimate of the objective function's gradient at those parameters.
    """
    theta = theta_init
    for step in range(n_steps):
        theta += eta * estimate_gradient(theta)
    return theta

latex(sgd)

What makes one gradient estimator better than another? Ideally, we want this estimator to be unbiased; that is, on average, it matches a single true gradient step:

$$\E [\tilde \nabla J(\theta)] = \nabla J(\theta).$$

We also want the variance of the estimator to be low so that its performance doesn’t change drastically at each step.

We can actually show that, for many “nice” functions, in a finite number of steps, SGD will find a $\theta$ that is “close” to a stationary point. From another perspective, for such functions, the local “landscape” of $J$ around $\theta$ becomes flatter and flatter the longer we run SGD.
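For instance, here is a sketch of an unbiased minibatch gradient estimator on a synthetic least-squares problem, plugged into the sgd function above; the dataset, batch size, and step count are made up for illustration.

import numpy as np

rng = np.random.default_rng(0)
X = jnp.asarray(rng.normal(size=(1000, 5)))  # synthetic inputs
theta_true = jnp.arange(1.0, 6.0)
y_targets = X @ theta_true                   # noiseless targets, for simplicity

def estimate_gradient_minibatch(theta):
    """Unbiased estimate of the gradient of J(theta) = -mean squared error, from a random minibatch."""
    idx = rng.choice(1000, size=32, replace=False)
    X_batch, y_batch = X[idx], y_targets[idx]
    def J_batch(theta):
        return -jnp.mean((X_batch @ theta - y_batch) ** 2)
    return jax.grad(J_batch)(theta)

theta_hat = sgd(jnp.zeros(5), estimate_gradient_minibatch, eta=0.01, n_steps=2000)
# theta_hat approaches theta_true even though each step only sees 32 of the 1000 examples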

We’ll now see a concrete application of gradient ascent in the context of policy optimization.

6.3 Policy (stochastic) gradient ascent

Remember that in RL, the primary goal is to find the optimal policy that achieves the maximum total reward, which we can express using the value function we defined in Definition 1.6:

$$J(\pi) := \E_{s_0 \sim \mu_0} V^{\pi} (s_0) = \E_{\tau \sim \rho^\pi} \sum_{\hi=0}^{\hor-1} r(s_\hi, a_\hi)$$

where $\rho^\pi$ is the distribution over trajectories induced by $\pi$ (see Definition 1.5).

(Note that we’ll continue to work in the undiscounted, finite-horizon case. Analogous results hold for the discounted, infinite-horizon setup.)

As shown by the notation, this is exactly the function $J$ that we want to maximize using gradient ascent. What variables are we optimizing over in this problem? Well, the objective function $J$ is a function of the policy $\pi$, but in general, $\pi$ is a function, and optimizing over the entire space of arbitrary input-output mappings would be intractable. Instead, we need to describe $\pi$ in terms of some finite set of parameters $\theta$.

6.3.1 Example policy parameterizations

What are some ways we could parameterize our policy?
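As a sketch, two common choices (assuming a finite action space) are a tabular softmax policy, with one logit per state-action pair, and a softmax policy whose logits are linear in some feature map. The helpers below are illustrative, not the text’s exact examples.

def tabular_softmax_policy(theta: Float[Array, "S A"]):
    """One parameter (logit) per state-action pair; states and actions are integer indices."""
    def pi(s: int, a: int) -> float:
        return jax.nn.softmax(theta[s])[a]
    return pi

def linear_softmax_policy(theta: Float[Array, " D"], featurize, actions):
    """Logits are linear in a feature vector featurize(s, a) of dimension D."""
    def pi(s, a) -> float:
        logits = jnp.array([featurize(s, b) @ theta for b in actions])
        return jax.nn.softmax(logits)[actions.index(a)]
    return pi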

Now that we have seen some examples of parameterized policies, we can write the total reward in terms of the parameters, overloading notation and letting $\rho_\theta := \rho^{\pi_\theta}$:

$$J(\theta) = \E_{\tau \sim \rho_\theta} R(\tau)$$

where $R(\tau) = \sum_{\hi=0}^{\hor-1} r(s_\hi, a_\hi)$ denotes the total reward in the trajectory.
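In the code sketches throughout this chapter, we will assume a trajectory is represented as a list of (state, action, reward) triples, so that the total reward is simply:

def total_reward(tau) -> float:
    """R(tau): sum of rewards along a trajectory, represented as a list of (s, a, r) triples."""
    return sum(r for _s, _a, r in tau)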

Now how do we maximize this function (the expected total reward) over the parameters? One simple idea would be to directly apply gradient ascent:

$$\theta^{k+1} = \theta^k + \eta \nabla J(\theta^k).$$

In order to apply this technique, we need to be able to evaluate the gradient $\nabla J(\theta)$. But $J(\theta)$ is very difficult, or even intractable, to compute exactly, since it involves taking an expectation over all possible trajectories $\tau$. Can we rewrite it in a form that’s more convenient to implement?

6.3.2 Importance Sampling

There is a general trick called importance sampling for evaluating difficult expectations. Suppose we want to estimate $\E_{x \sim p}[f(x)]$ where $p$ is hard or expensive to sample from, but easy to evaluate the likelihood $p(x)$ of. Suppose that we can easily sample from a different distribution $q$. Since an expectation is just a weighted average, we can sample $x$ from $q$, compute $f(x)$, and then reweight the results: if $x$ is very likely under $p$ but unlikely under $q$, we should boost its weighting, and if it is common under $q$ but uncommon under $p$, we should lower its weighting. The reweighting factor is exactly the likelihood ratio between the target distribution $p$ and the sampling distribution $q$:

$$\E_{x \sim p}[f(x)] = \sum_{x \in \mathcal{X}} f(x) p(x) = \sum_{x \in \mathcal{X}} f(x) \frac{p(x)}{q(x)} q(x) = \E_{x \sim q} \left[ \frac{p(x)}{q(x)} f(x) \right].$$

Doesn’t this seem too good to be true? If there were no drawbacks, we could use this to estimate any expectation of any function on any arbitrary distribution! The drawback is that the variance may be very large due to the likelihood ratio term. If there are values of $x$ that are very rare in the sampling distribution $q$, but common under $p$, then the likelihood ratio $p(x)/q(x)$ will cause the variance to blow up.
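As a toy numerical illustration of both the reweighting and the variance issue, here is a sketch on a three-element support; the distributions and the function f are made up.

import numpy as np

support = np.array([0.0, 1.0, 2.0])
p = np.array([0.1, 0.2, 0.7])   # target distribution (easy to evaluate; pretend it is hard to sample)
q = np.array([0.4, 0.3, 0.3])   # proposal distribution we can sample from

def f(x):
    return x ** 2

rng = np.random.default_rng(0)
idx = rng.choice(3, size=100_000, p=q)
weights = p[idx] / q[idx]                 # likelihood ratios p(x) / q(x)
estimate = np.mean(weights * f(support[idx]))
exact = np.sum(p * f(support))            # = 0.1 * 0 + 0.2 * 1 + 0.7 * 4 = 3.0
# estimate is close to exact; if q put tiny mass where p puts large mass, the weights (and variance) would blow up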

6.4 The REINFORCE policy gradient

Returning to RL, suppose there is some trajectory distribution $\rho(\tau)$ that is easy to sample from, such as a database of existing trajectories. We can then rewrite $\nabla J(\theta)$, a.k.a. the policy gradient, as follows. All gradients are being taken with respect to $\theta$.

$$\begin{aligned} \nabla J(\theta) & = \nabla \E_{\tau \sim \rho_\theta} [ R(\tau) ] \\ & = \nabla \E_{\tau \sim \rho} \left[ \frac{\rho_\theta(\tau)}{\rho(\tau)} R(\tau) \right] & & \text{likelihood ratio trick} \\ & = \E_{\tau \sim \rho} \left[ \frac{\nabla \rho_\theta(\tau)}{\rho(\tau)} R(\tau) \right] & & \text{switching gradient and expectation} \end{aligned}$$

Note that for $\rho = \rho_\theta$, the term inside the expectation, $\frac{\nabla \rho_\theta(\tau)}{\rho_\theta(\tau)}$, becomes the gradient of the log likelihood, $\nabla \log \rho_\theta(\tau)$, so that

$$\nabla J(\theta) = \E_{\tau \sim \rho_\theta} [ \nabla \log \rho_\theta(\tau) \cdot R(\tau)].$$

(The order of operations here is $\nabla (\log \rho_\theta)(\tau)$.)

Recall that when the state transitions are Markov (i.e. $s_{t}$ only depends on $s_{t-1}, a_{t-1}$) and the policy is time-homogeneous (i.e. $a_\hi \sim \pi_\theta (s_\hi)$), we can write out the likelihood of a trajectory under the policy $\pi_\theta$ autoregressively, as in Definition 1.5. Taking the log of the trajectory likelihood turns it into a sum of terms:

$$\log \rho_\theta(\tau) = \log \mu(s_0) + \sum_{\hi=0}^{\hor-1} \left( \log \pi_\theta(a_\hi \mid s_\hi) + \log P(s_{\hi+1} \mid s_\hi, a_\hi) \right)$$

When we take the gradient with respect to the parameters $\theta$, only the $\pi_\theta(a_\hi \mid s_\hi)$ terms depend on $\theta$. This gives the following expression for the policy gradient, known as the “REINFORCE” policy gradient (Williams, 1992):

$$\nabla J(\theta) = \E_{\tau \sim \rho_\theta} \left[ \sum_{\hi=0}^{\hor-1} \nabla_\theta \log \pi_{\theta}(a_\hi \mid s_\hi) R(\tau) \right]$$

This expression allows us to estimate the gradient by sampling a few trajectories from $\pi_\theta$, calculating the log-likelihoods of the chosen actions, and substituting these into the expression inside the brackets of (6.18). Then we can update the parameters $\theta$ in this direction to perform stochastic gradient ascent.

The rest of this chapter investigates ways to reduce the variance of this estimator by subtracting off certain correlated quantities.

@latex
def estimate_gradient_reinforce_pseudocode(env: gym.Env, pi, theta: Float[Array, " D"]):
    # sample a single trajectory from the current policy
    tau = sample_trajectory(env, pi(theta))
    nabla_hat = jnp.zeros_like(theta)
    total_reward = sum(r for _s, _a, r in tau)
    for s, a, r in tau:
        # log-likelihood of the chosen action, as a function of the parameters
        def policy_log_likelihood(theta: Float[Array, " D"]) -> float:
            return jnp.log(pi(theta)(s, a))
        nabla_hat += jax.grad(policy_log_likelihood)(theta) * total_reward
    return nabla_hat

estimate_gradient_reinforce_pseudocode

For some intuition into how this method works, recall that we update our parameters according to

$$\begin{aligned} \theta_{t+1} &= \theta_t + \eta \nabla J(\theta_t) \\ &= \theta_t + \eta \E_{\tau \sim \rho_{\theta_t}} [\nabla \log \rho_{\theta_t}(\tau) \cdot R(\tau)]. \end{aligned}$$

Consider the “good” trajectories where $R(\tau)$ is large. Then $\theta$ gets updated so that these trajectories become more likely. To see why, recall that $\rho_{\theta}(\tau)$ is the likelihood of the trajectory $\tau$ under the policy $\pi_\theta$, so the gradient points in the direction that makes $\tau$ more likely.

6.5 Baselines and advantages

A central idea from supervised learning is the bias-variance decomposition, which shows that the mean squared error of an estimator is the sum of its squared bias and its variance. The REINFORCE gradient estimator (6.18) is already unbiased, meaning that its expectation over trajectories is the true policy gradient. Can we find ways to reduce its variance as well?

As a first step, consider that the action taken at step $\hi$ does not affect the rewards from previous timesteps, since they’re already in the past. You can also show rigorously that this is the case, and that we only need to consider the present and future rewards to calculate the policy gradient:

$$\nabla J(\theta) = \E_{\tau \sim \rho_\theta} \left[ \sum_{\hi=0}^{\hor-1} \nabla_\theta \log \pi_{\theta}(a_\hi \mid s_\hi) \sum_{\hi' = \hi}^{\hor-1} r(s_{\hi'}, a_{\hi'}) \right]$$

Furthermore, by a conditioning argument, we can replace the inner sum over remaining rewards with the policy’s Q-function, evaluated at the current state:

$$\nabla J(\theta) = \E_{\tau \sim \rho_\theta} \left[ \sum_{\hi=0}^{\hor-1} \nabla_\theta \log \pi_{\theta}(a_\hi \mid s_\hi) Q^{\pi_\theta}(s_{\hi}, a_{\hi}) \right]$$

Exercise: Prove that this is equivalent to the previous definitions. What modification to the expression must be made for the discounted, infinite-horizon setting?

We can further reduce variance by subtracting a baseline function $b_\hi : \mathcal{S} \to \mathbb{R}$ at each timestep $\hi$. This modifies the policy gradient as follows:

$$\nabla J(\theta) = \E_{\tau \sim \rho_\theta} \left[ \sum_{\hi=0}^{\hor-1} \nabla \log \pi_\theta (a_\hi \mid s_\hi) \left( Q^{\pi_\theta}(s_\hi, a_\hi) - b_\hi(s_\hi) \right) \right].$$

(Again, you should try to prove that this equality still holds.) For example, we might want $b_\hi$ to estimate the average reward-to-go at a given timestep:

$$b_\hi^\theta = \E_{\tau \sim \rho_\theta} R_\hi(\tau).$$

As a better baseline, we could instead choose the value function. Note that the random variable $Q^\pi_\hi(s, a) - V^\pi_\hi(s)$, where the randomness is taken over the actions, is centered around zero. (Recall $V^\pi_\hi(s) = \E_{a \sim \pi} Q^\pi_\hi(s, a)$.) This quantity matches the intuition given in Note 6.1: it is positive for actions that are better than average (in state $s$), and negative for actions that are worse than average. In fact, this quantity has a particular name: the advantage function.

This measures how much better this action does than the average for that policy. (Note that for an optimal policy $\pi^\star$, the advantage of a given state-action pair is always zero or negative.)

We can now express the policy gradient as follows. Note that the advantage function effectively replaces the $Q$-function from (6.25):

$$\nabla J(\theta) = \E_{\tau \sim \rho_\theta} \left[ \sum_{\hi=0}^{\hor-1} \nabla \log \pi_\theta(a_\hi \mid s_\hi) A^{\pi_\theta}_\hi (s_\hi, a_\hi) \right].$$

Note that to avoid correlations between the gradient estimator and the value estimator (i.e. baseline), we must estimate them with independently sampled trajectories:

def pg_with_learned_baseline(env: gym.Env, pi, eta: float, theta_init, K: int, N: int) -> Float[Array, " D"]:
    theta = theta_init
    for k in range(K):
        trajectories = sample_trajectories(env, pi(theta), N)
        V_hat = fit_value(trajectories)  # baseline fitted on an independent batch of trajectories
        tau = sample_trajectories(env, pi(theta), 1)[0]  # a fresh trajectory for the gradient estimate
        nabla_hat = jnp.zeros_like(theta)  # gradient estimator

        for h, (s, a, _r) in enumerate(tau):
            def log_likelihood(theta_opt):
                return jnp.log(pi(theta_opt)(s, a))
            nabla_hat = nabla_hat + jax.grad(log_likelihood)(theta) * (return_to_go(tau, h) - V_hat(s))

        theta = theta + eta * nabla_hat
    return theta

latex(pg_with_learned_baseline)

Note that you could also generalize this by allowing the learning rate η to vary across steps, or by taking multiple trajectories τ and computing the sample average of the gradient estimates.

The baseline estimation step fit_value can be done using any appropriate supervised learning algorithm. Note that the gradient estimator will be unbiased regardless of the baseline.
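The helpers return_to_go and fit_value above are left abstract. Here is a minimal sketch of what they might look like, assuming trajectories are lists of (state, action, reward) triples and states are hashable (e.g. discrete), with a simple Monte Carlo average standing in for the supervised learning step.

def return_to_go(tau, h: int) -> float:
    """Sum of rewards from timestep h onwards in a single trajectory."""
    return sum(r for _s, _a, r in tau[h:])

def fit_value(trajectories):
    """Monte Carlo value baseline: average the observed return-to-go at each visited state.

    In practice, one would instead fit a parametric regressor (e.g. a neural network).
    """
    totals: dict = {}
    counts: dict = {}
    for tau in trajectories:
        for h, (s, _a, _r) in enumerate(tau):
            totals[s] = totals.get(s, 0.0) + return_to_go(tau, h)
            counts[s] = counts.get(s, 0) + 1
    return lambda s: totals.get(s, 0.0) / counts.get(s, 1)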

6.6 Comparing policy gradient algorithms to policy iteration

What advantages does the policy gradient algorithm have over the policy iteration algorithms covered in Section 1.5.3.2?

To analyze the difference between them, we’ll make use of the performance difference lemma, which provides an expression for comparing the difference between two value functions.

The PDL gives insight into why fitted approaches such as PI don’t work as well in the “full” RL setting. To see why, let’s consider a single iteration of policy iteration, where policy $\pi$ gets updated to $\tilde \pi$. We’ll assume these policies are deterministic. Suppose the new policy $\tilde \pi$ chooses some action with a negative advantage with respect to $\pi$. That is, when acting according to $\pi$, taking the action from $\tilde \pi$ would perform worse than expected. Define $\Delta_\infty$ to be the most negative advantage, that is, $\Delta_\infty = \min_{s \in \mathcal{S}} A^{\pi}_\hi(s, \tilde \pi(s))$. Plugging this into Theorem 6.1 gives

$$\begin{aligned} V_0^{\tilde \pi}(s) - V_0^{\pi}(s) &= \E_{\tau \sim \rho_{\tilde \pi, s}} \left[ \sum_{\hi=0}^{\hor-1} A_\hi^{\pi}(s_\hi, a_\hi) \right] \\ &\ge H \Delta_\infty \\ V_0^{\tilde \pi}(s) &\ge V_0^{\pi}(s) - H|\Delta_\infty|. \end{aligned}$$

That is, for some state $s$, the lower bound on the performance of $\tilde \pi$ is lower than the performance of $\pi$. This doesn’t mean that $\tilde \pi$ will necessarily perform worse than $\pi$; it only says that it might. If such worst-case states do exist, though, PI does nothing to avoid situations where the new policy often visits them: it does not enforce that the trajectory distributions $\rho_\pi$ and $\rho_{\tilde \pi}$ be close to each other. In other words, the “training distribution” that our prediction rule is fitted on, $\rho_\pi$, may differ significantly from the “evaluation distribution” $\rho_{\tilde \pi}$.

On the other hand, policy gradient methods do, albeit implicitly, encourage $\rho_\pi$ and $\rho_{\tilde \pi}$ to be similar. Suppose that the mapping from policy parameters to trajectory distributions is relatively smooth. Then, by adjusting the parameters only a small distance, the new policy will also have a similar trajectory distribution. But this is not very rigorous, and in practice the parameter-to-distribution mapping may not be so smooth. Can we constrain the distance between the resulting distributions more explicitly?

This brings us to the next three methods:

  • trust region policy optimization (TRPO), which explicitly constrains the difference between the distributions before and after each step;
  • the natural policy gradient (NPG), a first-order approximation of TRPO;
  • proximal policy optimization (PPO), a “soft relaxation” of TRPO.

6.7 Trust region policy optimization

We saw above that policy gradient methods are effective because they implicitly constrain how much the policy changes at each iteration. Can we design an algorithm that explicitly constrains the “step size”? That is, we want to improve the policy as much as possible, measured in terms of the r.h.s. of Theorem 6.1, while ensuring that its trajectory distribution does not change too much:

$$\begin{aligned} \theta^{k+1} &\gets \arg\max_{\theta^{\text{opt}}} \E_{s_0, \dots, s_{H-1} \sim \pi^{k}} \left[ \sum_{\hi=0}^{\hor-1} \E_{a_\hi \sim \pi^{\theta^\text{opt}}(s_\hi)} A^{\pi^{k}}(s_\hi, a_\hi) \right] \\ & \text{where } \text{distance}(\rho_{\theta^{\text{opt}}}, \rho_{\theta^k}) < \delta \end{aligned}$$

Note that we have made a small change to the r.h.s. expression: we use the states sampled from the old policy, and only use the actions from the new policy. It would be computationally infeasible to sample entire trajectories from $\pi_\theta$ as we are optimizing over $\theta$. On the other hand, if $\pi_\theta$ returns a vector representing a probability distribution over actions, then evaluating the expected advantage with respect to this distribution only requires taking a dot product. This approximation also matches the r.h.s. of the PDL to first order in $\theta$. (We will elaborate more on this later.)

How do we describe the distance between $\rho_{\theta^{\text{opt}}}$ and $\rho_{\theta^k}$? We’ll use the Kullback-Leibler divergence (KLD):

Both the objective function and the KLD constraint involve a weighted average over the space of all trajectories. This is intractable in general, so we need to estimate the expectation. As before, we can do this by taking an empirical average over samples from the trajectory distribution. This gives us the following pseudocode:

def kl_div_trajectories(pi, theta_1, theta_2, trajectories):
    """Assume trajectories are sampled from pi(theta_1)."""
    kl_div = 0
    for tau in trajectories:
        for s, a, _r in tau:
            kl_div += jnp.log(pi(theta_1)(s, a)) - jnp.log(pi(theta_2)(s, a))
    return kl_div / len(trajectories)

latex(kl_div_trajectories)
def trpo(env, pi, δ, theta_init, n_iters, n_interactions):
    theta = theta_init
    for k in range(n_iters):
        trajectories = sample_trajectories(env, pi(theta), n_interactions)
        A_hat = fit_advantage(trajectories)

        def approximate_gain(theta_opt):
            # expected advantage of the candidate policy's action distribution
            # on states visited by the old policy
            A_total = 0
            for tau in trajectories:
                for s, _a, _r in tau:
                    for a in env.action_space:
                        A_total += pi(theta_opt)(s, a) * A_hat(s, a)
            return A_total

        def constraint(theta_opt):
            return kl_div_trajectories(pi, theta, theta_opt, trajectories) <= δ

        theta = optimize(approximate_gain, constraint)

    return theta

latex(trpo)

The above isn’t entirely complete: we still need to solve the actual optimization problem at each step. Unless we know additional properties of the problem, this might be an intractable optimization. Do we need to solve it exactly, though? Instead, if we assume that both the objective function and the constraint are somewhat smooth in terms of the policy parameters, we can use their Taylor expansions to give us a simpler optimization problem with a closed-form solution. This brings us to the natural policy gradient algorithm.

6.8 Natural policy gradient

We take a linear (first-order) approximation to the objective function and a quadratic (second-order) approximation to the KL divergence constraint about the current estimate $\theta^k$. This results in the optimization problem

$$\begin{gathered} \max_\theta \nabla_\theta J(\pi_{\theta^k})^\top (\theta - \theta^k) \\ \text{where } \frac{1}{2} (\theta - \theta^k)^\top F_{\theta^k} (\theta - \theta^k) \le \delta \end{gathered}$$

where $F_{\theta^k}$ is the Fisher information matrix defined below.

This is a convex optimization problem with a closed-form solution. To see why, it helps to visualize the case where θ is two-dimensional: the constraint describes the inside of an ellipse, and the objective function is linear, so we can find the extreme point on the boundary of the ellipse. We recommend Boyd & Vandenberghe (2004) for a comprehensive treatment of convex optimization.

More generally, for a higher-dimensional $\theta$, we can compute the global optimum by setting the gradient of the Lagrangian to zero:

$$\begin{aligned} \mathcal{L}(\theta, \alpha) & = \nabla J(\pi_{\theta^k})^\top (\theta - \theta^k) - \alpha \left[ \frac{1}{2} (\theta - \theta^k)^\top F_{\theta^k} (\theta - \theta^k) - \delta \right] \\ \nabla \mathcal{L}(\theta^{k+1}, \alpha) & := 0 \\ \implies \nabla J(\pi_{\theta^k}) & = \alpha F_{\theta^k} (\theta^{k+1} - \theta^k) \\ \theta^{k+1} & = \theta^k + \eta F_{\theta^k}^{-1} \nabla J(\pi_{\theta^k}) \\ \text{where } \eta & = \sqrt{\frac{2 \delta}{\nabla J(\pi_{\theta^k})^\top F_{\theta^k}^{-1} \nabla J(\pi_{\theta^k})}} \end{aligned}$$

This gives us the closed-form update. Now the only challenge is to estimate the Fisher information matrix, since, as with the KL divergence constraint, it is an expectation over trajectories, and computing it exactly is therefore typically intractable.

As you can see, the NPG is the “basic” policy gradient algorithm we saw above, but with the gradient transformed by the inverse Fisher information matrix. This matrix can be understood as accounting for the geometry of the parameter space. The typical gradient descent algorithm implicitly measures distances between parameters using the typical Euclidean distance. Here, where the parameters map to a distribution, using the natural gradient update is equivalent to optimizing over distribution space rather than parameter space, where distance between distributions is measured by the Definition 6.3.
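As a sketch of what this update looks like in code, suppose we estimate the Fisher matrix as the average outer product of per-trajectory score vectors $\nabla_\theta \log \rho_\theta(\tau) = \sum_{\hi} \nabla_\theta \log \pi_\theta(a_\hi \mid s_\hi)$. The function below is illustrative: the damping term is a common numerical fix, and solving the linear system avoids forming the inverse explicitly.

def natural_gradient_step(grad_hat: Float[Array, " D"], scores: Float[Array, "N D"], delta: float):
    """One NPG update direction from an estimated policy gradient and per-trajectory score vectors."""
    N = scores.shape[0]
    F_hat = scores.T @ scores / N                       # empirical Fisher information matrix
    F_hat = F_hat + 1e-6 * jnp.eye(F_hat.shape[0])      # damping for numerical stability
    direction = jnp.linalg.solve(F_hat, grad_hat)       # F^{-1} grad J, without an explicit inverse
    eta = jnp.sqrt(2 * delta / (grad_hat @ direction))  # step size implied by the KL constraint
    return eta * direction

# the update is then: theta = theta + natural_gradient_step(grad_hat, scores, delta)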

Though the NPG now gives a closed-form optimization step, it requires computing the inverse Fisher information matrix, which typically scales as $O((\dim \Theta)^3)$. This can be expensive if the parameter space is large. Can we find an algorithm that works in linear time with respect to the dimension of the parameter space?

6.9 Proximal policy optimization

We can relax the TRPO optimization problem in a different way: Rather than imposing a hard constraint on the KL distance, we can instead impose a soft constraint by incorporating it into the objective and penalizing parameter values that drastically change the trajectory distribution.

$$\begin{aligned} \theta^{k+1} &\gets \arg\max_{\theta} \E_{s_0, \dots, s_{H-1} \sim \rho_{\pi^{k}}} \left[ \sum_{\hi=0}^{\hor-1} \E_{a_\hi \sim \pi_{\theta}(s_\hi)} A^{\pi^{k}}(s_\hi, a_\hi) \right] - \lambda \kl{\rho_{\theta}}{\rho_{\theta^k}} \end{aligned}$$

Here $\lambda$ is a regularization hyperparameter that controls the tradeoff between the two terms. This is the objective of the proximal policy optimization algorithm (Schulman et al., 2017).

How do we solve this optimization? Let us begin by simplifying the $\kl{\rho_{\pi^k}}{\rho_{\pi_{\theta}}}$ term. Expanding gives

$$\begin{aligned} \kl{\rho_{\pi^k}}{\rho_{\pi_{\theta}}} & = \E_{\tau \sim \rho_{\pi^k}} \left[\log \frac{\rho_{\pi^k}(\tau)}{\rho_{\pi_{\theta}}(\tau)}\right] \\ & = \E_{\tau \sim \rho_{\pi^k}} \left[ \sum_{\hi=0}^{\hor-1} \log \frac{\pi^k(a_\hi \mid s_\hi)}{\pi_{\theta}(a_\hi \mid s_\hi)}\right] & \text{state transitions cancel} \\ & = \E_{\tau \sim \rho_{\pi^k}} \left[ \sum_{\hi=0}^{\hor-1} \log \frac{1}{\pi_{\theta}(a_\hi \mid s_\hi)}\right] + c \end{aligned}$$

where $c$ is some constant with respect to $\theta$, and can be ignored. This gives the objective

$$\ell^k(\theta) = \E_{s_0, \dots, s_{H-1} \sim \rho_{\pi^{k}}} \left[ \sum_{\hi=0}^{\hor-1} \E_{a_\hi \sim \pi_{\theta}(s_\hi)} A^{\pi^{k}}(s_\hi, a_\hi) \right] - \lambda \E_{\tau \sim \rho_{\pi^k}} \left[ \sum_{\hi=0}^{\hor-1} \log \frac{1}{\pi_{\theta}(a_\hi \mid s_\hi)}\right]$$

Once again, this takes an expectation over trajectories. But here we cannot directly sample trajectories from $\pi^k$, since in the first term, the actions actually come from $\pi_\theta$. To make this term line up with the other expectation, we would need the actions to also come from $\pi^k$.

This should sound familiar: we want to estimate an expectation over one distribution by sampling from another. We can once again use importance sampling (Section 6.3.2) to rewrite the inner expectation:

$$\E_{a_\hi \sim \pi_{\theta}(s_\hi)} A^{\pi^{k}}(s_\hi, a_\hi) = \E_{a_\hi \sim \pi^k(s_\hi)} \frac{\pi_\theta(a_\hi \mid s_\hi)}{\pi^k(a_\hi \mid s_\hi)} A^{\pi^{k}}(s_\hi, a_\hi)$$

Now we can combine the expectations together to get the objective

$$\ell^k(\theta) = \E_{\tau \sim \rho_{\pi^k}} \left[ \sum_{\hi=0}^{\hor-1} \left( \frac{\pi_\theta(a_\hi \mid s_\hi)}{\pi^k(a_\hi \mid s_\hi)} A^{\pi^k}(s_\hi, a_\hi) - \lambda \log \frac{1}{\pi_\theta(a_\hi \mid s_\hi)} \right) \right]$$

Now we can estimate this function by a sample average over trajectories from $\pi^k$. Remember that to complete a single iteration of PPO, we execute

$$\theta^{k+1} \gets \arg\max_{\theta} \ell^k(\theta).$$

If $\ell^k$ is differentiable, we can optimize it by gradient ascent, completing a single iteration of PPO.

from typing import TypeVar

State = TypeVar("State")
Action = TypeVar("Action")

def ppo(
    env,
    pi: Callable[[Float[Array, " D"]], Callable[[State, Action], float]],
    λ: float,
    theta_init: Float[Array, " D"],
    n_iters: int,
    n_fit_trajectories: int,
    n_sample_trajectories: int,
):
    theta = theta_init
    for k in range(n_iters):
        fit_trajectories = sample_trajectories(env, pi(theta), n_fit_trajectories)
        A_hat = fit(fit_trajectories)  # estimate the advantage function of the current policy

        sampled_trajectories = sample_trajectories(env, pi(theta), n_sample_trajectories)

        def objective(theta_opt):
            total_objective = 0
            for tau in sampled_trajectories:
                for s, a, _r in tau:
                    # importance-weighted advantage plus the KL penalty (up to a constant)
                    total_objective += pi(theta_opt)(s, a) / pi(theta)(s, a) * A_hat(s, a) + λ * jnp.log(pi(theta_opt)(s, a))
            return total_objective / n_sample_trajectories

        theta = optimize(objective, theta)

    return theta

latex(ppo)

6.10 Summary

Policy gradient methods are a powerful family of algorithms that directly optimize the expected total reward by iteratively updating the policy parameters. Precisely, we estimate the gradient of the expected total reward (with respect to the parameters), and update the parameters in that direction. But estimating the gradient is a tricky task! We saw many ways to reduce the variance of the gradient estimator, culminating in the advantage-based expression (6.29).

But updating the parameters doesn’t entirely solve the problem: sometimes, a small step in the parameters might lead to a big step in the policy. To avoid changing the policy too much at each step, we must account for the curvature in the parameter space. We first did this explicitly with trust region policy optimization (Section 6.7), and then saw ways to relax the constraint in Definition 6.5 and Section 6.9.

These methods remain popular to this day, especially because they integrate efficiently with deep neural networks for representing complex functions.

References
  1. Baydin, A. G., Pearlmutter, B. A., Radul, A. A., & Siskind, J. M. (2018). Automatic Differentiation in Machine Learning: A Survey. arXiv. 10.48550/arXiv.1502.05767
  2. Williams, R. J. (1992). Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning. Machine Learning, 8(3), 229–256. 10.1007/BF00992696
  3. Boyd, S., & Vandenberghe, L. (2004). Convex Optimization. Cambridge University Press.
  4. Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017). Proximal Policy Optimization Algorithms. arXiv. 10.48550/arXiv.1707.06347