
5 Fitted Dynamic Programming Algorithms

5.1 Introduction

We borrow these definitions from Chapter 1 (Markov Decision Processes):

from typing import NamedTuple, Callable, Optional
from jaxtyping import Float, Array
import jax.numpy as np
from jax import grad, vmap
import jax.random as rand
from tqdm import tqdm
import gymnasium as gym

key = rand.PRNGKey(184)


class Transition(NamedTuple):
    s: int
    a: int
    r: float


Trajectory = list[Transition]


def get_num_actions(trajectories: list[Trajectory]) -> int:
    """Get the number of actions in the dataset. Assumes actions range from 0 to A-1."""
    return max(max(t.a for t in τ) for τ in trajectories) + 1


State = Float[Array, "..."]  # arbitrary shape

# assume finite `A` actions and f outputs an array of Q-values
# i.e. Q(s, a, h) is implemented as f(s, h)[a]
QFunction = Callable[[State, int], Float[Array, " A"]]


def Q_zero(A: int) -> QFunction:
    """A Q-function that always returns zero."""
    return lambda s, h: np.zeros(A)  # ignores both the state and the timestep


# a deterministic time-dependent policy
Policy = Callable[[State, int], int]


def q_to_greedy(Q: QFunction) -> Policy:
    """Get the greedy policy for the given state-action value function."""
    # cast to a Python int, matching Policy's declared return type
    return lambda s, h: int(np.argmax(Q(s, h)))

Chapter 1 (Markov Decision Processes) discussed the case of finite MDPs, where the state and action spaces $\mathcal{S}$ and $\mathcal{A}$ were finite. This gave us a closed-form expression for computing the r.h.s. of the Bellman one-step consistency equation. In this chapter, we consider the case of large or continuous state spaces, where the state space is too large to be enumerated. In this case, we need to approximate the value function and Q-function using methods from supervised learning.

We will first take a quick detour to introduce the empirical risk minimization framework for function approximation. We will then see its application to fitted RL algorithms, which attempt to learn the optimal value function (and the optimal policy) from a dataset of trajectories.

5.2 Empirical risk minimization

The supervised learning task is as follows: we seek to learn the relationship between some input variables $x$ and some output variable $y$ (drawn from their joint distribution). Precisely, we want to find a function $\hat f : x \mapsto y$ that minimizes the squared error of the prediction:

$$\hat f = \arg\min_{f} \mathbb{E}[(y - f(x))^2]$$

An equivalent framing is that we seek to approximate the conditional expectation of $y$ given $x$, which is exactly the minimizer of the squared error above:

$$\hat f(x) = \mathbb{E}[y \mid x].$$
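A standard argument shows why the two framings agree. Add and subtract $\mathbb{E}[y \mid x]$ inside the square. The cross term vanishes because $\mathbb{E}[(y - \mathbb{E}[y \mid x])\, g(x)] = 0$ for any function $g$ of $x$ (by the tower property of conditional expectation), leaving

$$
\mathbb{E}[(y - f(x))^2] = \underbrace{\mathbb{E}[(y - \mathbb{E}[y \mid x])^2]}_{\text{independent of } f} + \underbrace{\mathbb{E}[(\mathbb{E}[y \mid x] - f(x))^2]}_{\geq 0}.
$$

The first term does not depend on $f$, and the second is nonnegative and vanishes when $f(x) = \mathbb{E}[y \mid x]$, so the conditional expectation is a minimizer of the squared error.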

In most applications, the joint distribution of $x, y$ is unknown or extremely complex, and so we can’t analytically evaluate $\mathbb{E}[y \mid x]$. Instead, our strategy is to draw $N$ samples $(x_i, y_i)$ from the joint distribution of $x$ and $y$, and then use the sample average $\sum_{i=1}^N (y_i - f(x_i))^2 / N$ to approximate the mean squared error. Then we use a fitting method to find a function $\hat f$ that minimizes this objective and thus approximates the conditional expectation. This approach is called empirical risk minimization.
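As a minimal sketch of ERM (not code from this chapter), the snippet below fits a linear model $f(x) = wx + b$ by gradient descent on the empirical mean squared error, using `jax.grad`. The synthetic data ($y = 2x + 1$ plus Gaussian noise), the learning rate, and the step count are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as np
import jax.random as rand


def empirical_risk(params, x, y):
    """Sample average of the squared error, sum_i (y_i - f(x_i))^2 / N, for f(x) = w x + b."""
    w, b = params
    return np.mean((y - (w * x + b)) ** 2)


def erm_linear(x, y, lr=0.1, steps=2000):
    """Minimize the empirical risk over (w, b) with plain gradient descent."""
    risk_grad = jax.jit(jax.grad(empirical_risk))  # gradient w.r.t. params only
    params = (np.array(0.0), np.array(0.0))
    for _ in range(steps):
        g_w, g_b = risk_grad(params, x, y)
        params = (params[0] - lr * g_w, params[1] - lr * g_b)
    return params


# Synthetic samples (x_i, y_i) from a known joint distribution: y = 2x + 1 + noise
key_x, key_noise = rand.split(rand.PRNGKey(0))
x = rand.uniform(key_x, (1000,))
y = 2 * x + 1 + 0.1 * rand.normal(key_noise, (1000,))

w_hat, b_hat = erm_linear(x, y)
print(f"w ≈ {float(w_hat):.2f}, b ≈ {float(b_hat):.2f}")  # should land near 2 and 1
```

Here the model class is tiny, so ERM recovers the true line; the fitted RL algorithms below plug richer model classes into the same "minimize the sample-average loss" recipe.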

5.3 Fitted value iteration

Let us apply ERM to the RL problem of computing the optimal policy / value function.

How did we compute the optimal value function in MDPs with finite state and action spaces?

  • In a finite-horizon MDP, we can use dynamic programming, working backwards from the end of the time horizon, to compute the optimal value function exactly.

  • In an infinite-horizon MDP, we can use value iteration, which iterates the Bellman optimality operator (1.54) to approximately compute the optimal value function.

Our existing approaches represent the value function, and the MDP itself, in matrix notation. But what happens if the state space is extremely large, or even infinite (e.g. real-valued)? Then computing a weighted sum over all possible next states, which is required to compute the Bellman operator, becomes intractable.
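For contrast, here is a minimal sketch of exact finite-horizon dynamic programming (backward induction) in the tabular case, assuming the MDP is given as a transition tensor `P[s, a, s']` and a reward table `r[s, a]`. These names and the two-state toy MDP are chosen here for illustration. The product `P @ V_next` is exactly the weighted sum over all next states; it requires enumerating $\mathcal{S}$ and costs $O(|\mathcal{S}|^2 |\mathcal{A}|)$ per timestep, which is what breaks down for huge or continuous state spaces.

```python
import jax.numpy as np


def backward_induction(P, r, H):
    """Exact finite-horizon DP. P: (S, A, S) transition probabilities, r: (S, A) rewards.

    Returns Q of shape (H, S, A), where Q[h, s, a] is the optimal Q-value at timestep h.
    """
    S, A = r.shape
    V_next = np.zeros(S)  # value after the final timestep is zero
    Qs = []
    for h in reversed(range(H)):
        # Bellman optimality backup: sum over next states s' of P[s, a, s'] * V_next[s']
        Q_h = r + P @ V_next
        Qs.append(Q_h)
        V_next = Q_h.max(axis=1)
    return np.stack(Qs[::-1])


# Toy MDP: action 0 stays in place, action 1 switches state; reward 1 for being in state 1.
P = np.array([[[1.0, 0.0], [0.0, 1.0]],
              [[0.0, 1.0], [1.0, 0.0]]])
r = np.array([[0.0, 0.0],
              [1.0, 1.0]])
Q = backward_induction(P, r, H=2)
print(Q[0])  # expected values at h = 0: [[0, 1], [2, 1]]
```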

Instead, we will need to use function approximation methods from supervised learning to solve for the value function in an alternative way.

In particular, suppose we have a dataset of $N$ trajectories $\tau_1, \dots, \tau_N \sim \rho_{\pi}$ from some policy $\pi$ (called the data collection policy) acting in the MDP of interest. Let us indicate the trajectory index in the superscript, so that

$$\tau_i = \{ s_0^i, a_0^i, r_0^i, s_1^i, a_1^i, r_1^i, \dots, s_{H-1}^i, a_{H-1}^i, r_{H-1}^i \}.$$

def collect_data(
    env: gym.Env, N: int, H: int, key: rand.PRNGKey, π: Optional[Policy] = None
) -> list[Trajectory]:
    """Collect a dataset of trajectories from the given policy (or a random one)."""
    trajectories = []
    seeds = [rand.bits(k).item() for k in rand.split(key, N)]
    for i in tqdm(range(N)):
        τ = []
        s, _ = env.reset(seed=seeds[i])
        for h in range(H):
            # follow π if one is given; otherwise sample a uniformly random action
            a = π(s, h) if π else env.action_space.sample()
            s_next, r, terminated, truncated, _ = env.step(a)
            τ.append(Transition(s, a, r))
            if terminated or truncated:
                break
            s = s_next
        trajectories.append(τ)
    return trajectories
env = gym.make("LunarLander-v3")
trajectories = collect_data(env, 100, 300, key)
trajectories[0][:5]  # show first five transitions from first trajectory

Can we view the dataset of trajectories as a “labelled dataset” in order to apply supervised learning to approximate the optimal Q-function? Yes! Recall that we can characterize the optimal Q-function using the Bellman optimality equations, which don’t depend on an actual policy:

$$Q_h^\star(s, a) = r(s, a) + \mathbb{E}_{s' \sim P(s, a)} [\max_{a'} Q_{h+1}^\star(s', a')]$$

We can think of the arguments to the Q-function (the current state, action, and timestep $h$) as the inputs $x$, and the r.h.s. of the above equation as the target $f(x)$. Note that the r.h.s. can also be expressed as a conditional expectation:

$$f(x) = \mathbb{E}[y \mid x] \quad \text{where} \quad y = r(s_h, a_h) + \max_{a'} Q^\star_{h+1}(s', a').$$

Approximating the conditional expectation is precisely the task that Section 5.2 is suited for!

Our dataset above gives us up to $N \cdot H$ samples (fewer if some trajectories terminate early):

$$x_{ih} = (s_h^i, a_h^i, h) \qquad y_{ih} = r(s_h^i, a_h^i) + \max_{a'} Q^\star_{h+1}(s_{h+1}^i, a')$$

def get_X(trajectories: list[Trajectory]):
    """
    Build the regression inputs x_{ih} = (s_h^i, a_h^i, h): returns stacked arrays
    of states, actions, and timesteps, with one row per transition in the dataset.
    """
    rows = [(τ[h].s, τ[h].a, h) for τ in trajectories for h in range(len(τ))]
    return [np.stack(ary) for ary in zip(*rows)]


def get_y(
    trajectories: list[Trajectory],
    f: Optional[QFunction] = None,
    π: Optional[Policy] = None,
):
    """
    Transform the dataset of trajectories into labels for supervised learning,
    bootstrapping from the current estimate `f` at the next state.
    If `π` is None, estimates the optimal Q-function (max over next actions).
    Otherwise, estimates the Q-function of π (next action chosen by π).
    """
    f = f or Q_zero(get_num_actions(trajectories))
    y = []
    for τ in trajectories:
        for h in range(len(τ) - 1):
            r = τ[h].r
            s_next = τ[h + 1].s  # bootstrap from s_{h+1}, not the current state s_h
            Q_values = f(s_next, h + 1)
            y.append(r + (Q_values[π(s_next, h + 1)] if π else Q_values.max()))
        # last step: no next state in the data, so the label is just the reward
        y.append(τ[-1].r)
    return np.array(y)
s, a, h = get_X(trajectories[:1])
print("states:", s[:5])
print("actions:", a[:5])
print("timesteps:", h[:5])
get_y(trajectories[:1])[:5]

Then we can use empirical risk minimization to find a function $\hat f$ that approximates the optimal Q-function.

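# We will see some examples of fitting methods in the next section.
# A fitting method takes the inputs X = [states, actions, timesteps] (as returned
# by get_X) and the labels y, and returns the fitted Q-function.
FittingMethod = Callable[[list[Array], Float[Array, " N"]], QFunction]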
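Purely to illustrate the shape of a fitting method, here is a hypothetical linear least-squares fit. It assumes `X` is the `[states, actions, timesteps]` list returned by `get_X`, models $Q(s, a, h) \approx w_a^\top [s, h, 1]$ with a separate weight vector per action, and solves each per-action least-squares problem with `jax.numpy.linalg.lstsq`. A model that is linear in the raw state is a strong simplifying assumption; the sketch only shows the plumbing.

```python
import jax.numpy as np


def make_linear_fit(A: int):
    """Hypothetical fitting-method factory: Q(s, h)[a] = W[a] @ [s, h, 1]."""

    def features(states, timesteps):
        # Flatten each state and append the timestep and a bias term: (N, D + 2)
        N = states.shape[0]
        return np.concatenate(
            [
                states.reshape(N, -1).astype(np.float32),
                timesteps.reshape(N, 1).astype(np.float32),
                np.ones((N, 1), dtype=np.float32),
            ],
            axis=1,
        )

    def fit(X, y):
        states, actions, timesteps = X
        Phi = features(states, timesteps)
        W = []
        for action in range(A):
            mask = actions == action
            if mask.any():
                # squared-loss ERM on the rows where this action was taken
                w, *_ = np.linalg.lstsq(Phi[mask], y[mask])
            else:
                w = np.zeros(Phi.shape[1])  # action never seen: predict 0
            W.append(w)
        W = np.stack(W)  # (A, D + 2)

        def Q(s, h):
            phi = features(np.asarray(s)[None], np.asarray(h)[None])[0]
            return W @ phi  # array of A Q-values, as QFunction expects

        return Q

    return fit
```

A usage check on synthetic labels where the linear model is exact:

```python
s = np.linspace(-1.0, 1.0, 20).reshape(20, 1)
h = np.arange(20) % 5
a = np.arange(20) % 2
y = np.where(a == 0, 2 * s[:, 0] + 0.5 * h + 1, -s[:, 0] + 3)
Q = make_linear_fit(2)([s, a, h], y)
Q(np.array([0.5]), 2)  # approximately [3.0, 2.5]
```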

But notice that the definition of $y_{ih}$ depends on the Q-function itself! How can we resolve this circular dependency? Recall that we faced the same issue when evaluating a policy in an infinite-horizon MDP. There, we iterated the policy's Bellman operator (Definition 1.8), since we knew that the policy’s value function was a fixed point of that operator. We can apply the same strategy here, using the $\hat f$ from the previous iteration to compute the labels $y_{ih}$, and then using this new dataset to fit the next iterate.
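Written out as equations (a restatement of this procedure, assuming for simplicity that every trajectory has length $H$), iteration $k$ computes labels from the current estimate $\hat f^{(k)}$ and then refits, starting from $\hat f^{(0)} \equiv 0$ as with the default `Q_zero` in `fitted_q_iteration`. At the last step $h = H - 1$ there is no next state in the data, so the label is just the reward $r(s_{H-1}^i, a_{H-1}^i)$:

$$
y_{ih}^{(k)} = r(s_h^i, a_h^i) + \max_{a'} \hat f^{(k)}(s_{h+1}^i, a', h+1),
\qquad
\hat f^{(k+1)} = \arg\min_{f} \frac{1}{N H} \sum_{i=1}^{N} \sum_{h=0}^{H-1} \left( y_{ih}^{(k)} - f(s_h^i, a_h^i, h) \right)^2.
$$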

def fitted_q_iteration(
    trajectories: list[Trajectory],
    fit: FittingMethod,
    epochs: int,
    Q_init: Optional[QFunction] = None,
) -> QFunction:
    """
    Run fitted Q-function iteration using the given dataset.
    Returns an estimate of the optimal Q-function.
    """
    Q_hat = Q_init or Q_zero(get_num_actions(trajectories))
    X = get_X(trajectories)
    for _ in range(epochs):
        y = get_y(trajectories, Q_hat)
        Q_hat = fit(X, y)
    return Q_hat
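To see the iteration work end to end, here is a self-contained tabular sketch (not the chapter's code). It assumes a made-up deterministic MDP with two states and two actions (action 0 stays, action 1 switches, reward 1 for being in state 1), a flat dataset of transitions instead of trajectories, and a tabular fitting method that averages the labels in each $(h, s, a)$ cell, which is exact ERM over all tabular functions. With an exact fit, each sweep propagates reward information back one more timestep, so $H$ sweeps recover the exact backward-induction Q-values.

```python
import jax.numpy as np

# Hypothetical toy MDP: action 0 stays put, action 1 switches state;
# reward 1 for being in state 1 (regardless of action).
S, A, H = 2, 2, 3


def step(s, a):
    return s if a == 0 else 1 - s


def reward(s, a):
    return float(s == 1)


# Flat dataset of transitions (h, s, a, r, s'), covering every (h, s, a) cell once,
# as if gathered by an exploratory data collection policy.
rows = [(h, s, a, reward(s, a), step(s, a)) for h in range(H) for s in range(S) for a in range(A)]
h, s, a, r, s_next = (np.array(col) for col in zip(*rows))


def fit_tabular(y):
    """ERM over all tabular functions: the mean label in each (h, s, a) cell."""
    sums = np.zeros((H, S, A)).at[h, s, a].add(y)
    counts = np.zeros((H, S, A)).at[h, s, a].add(1.0)
    return sums / np.maximum(counts, 1.0)


def fitted_q_iteration_tabular(epochs):
    Q = np.zeros((H, S, A))  # Q_hat initialized to zero
    for _ in range(epochs):
        # labels y = r + max_a' Q_hat(s', a', h + 1); no bootstrap at the last timestep
        V_next = Q[np.minimum(h + 1, H - 1), s_next].max(axis=-1)
        y = r + np.where(h == H - 1, 0.0, V_next)
        Q = fit_tabular(y)
    return Q


Q = fitted_q_iteration_tabular(epochs=H)
print(Q[0])                  # Q-values at h = 0
print(Q[0].argmax(axis=-1))  # greedy action in each state at h = 0
```

With fewer than $H$ sweeps, the earliest timesteps have not yet received the bootstrapped values from later ones. The exact recovery here relies on the tabular fit being exact on every cell.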

5.4 Fitted policy evaluation

We can also use this fixed-point iteration to evaluate a given policy using the dataset (the policy need not be the one used to generate the trajectories):

def fitted_evaluation(
    trajectories: list[Trajectory],
    fit: FittingMethod,
    π: Policy,
    epochs: int,
    Q_init: Optional[QFunction] = None,
) -> QFunction:
    """
    Run fitted policy evaluation using the given dataset.
    Returns an estimate of the Q-function of the given policy.
    """
    Q_hat = Q_init or Q_zero(get_num_actions(trajectories))
    X = get_X(trajectories)
    for _ in tqdm(range(epochs)):
        y = get_y(trajectories, Q_hat, π)
        Q_hat = fit(X, y)
    return Q_hat
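A hedged tabular sketch of the same idea (a self-contained toy, not the chapter's code): compared with fitted Q-iteration, the only change is that the bootstrap term uses the evaluated policy's action $\pi(s', h+1)$ instead of the max over actions. The two-state MDP, the flat transition dataset, and representing the policy as an integer array `pi[h, s]` are assumptions made for illustration. Here we evaluate the policy that always switches state.

```python
import jax.numpy as np

# Hypothetical toy MDP: action 0 stays put, action 1 switches state; reward 1 in state 1.
S, A, H = 2, 2, 3
rows = [(h, s, a, float(s == 1), s if a == 0 else 1 - s)
        for h in range(H) for s in range(S) for a in range(A)]
h, s, a, r, s_next = (np.array(col) for col in zip(*rows))


def fit_tabular(y):
    """Mean label in each (h, s, a) cell: exact ERM over tabular functions."""
    sums = np.zeros((H, S, A)).at[h, s, a].add(y)
    counts = np.zeros((H, S, A)).at[h, s, a].add(1.0)
    return sums / np.maximum(counts, 1.0)


def fitted_evaluation_tabular(pi, epochs):
    """pi: integer array of shape (H, S) giving the action taken at each (h, s)."""
    Q = np.zeros((H, S, A))
    for _ in range(epochs):
        h_next = np.minimum(h + 1, H - 1)
        # labels y = r + Q_hat(s', pi(s', h + 1), h + 1); no bootstrap at the last timestep
        V_next = Q[h_next, s_next, pi[h_next, s_next]]
        y = r + np.where(h == H - 1, 0.0, V_next)
        Q = fit_tabular(y)
    return Q


always_switch = np.ones((H, S), dtype=np.int32)
Q_pi = fitted_evaluation_tabular(always_switch, epochs=H)
print(Q_pi[0])  # Q-values of the always-switch policy at h = 0
```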

5.5 Fitted policy iteration

We can use this policy evaluation algorithm to adapt the policy iteration algorithm to this new setting. The algorithm remains exactly the same (repeatedly make the policy greedy w.r.t. its own value function), except that now we must evaluate the policy (i.e. compute its value function) using the iterative `fitted_evaluation` algorithm.

def fitted_policy_iteration(
    trajectories: list[Trajectory],
    fit: FittingMethod,
    epochs: int,
    evaluation_epochs: int,
    π_init: Optional[Policy] = lambda s, h: 0,  # constant zero policy
):
    """Run fitted policy iteration using the given dataset."""
    π = π_init
    for _ in range(epochs):
        Q_hat = fitted_evaluation(trajectories, fit, π, evaluation_epochs)
        π = q_to_greedy(Q_hat)
    return π
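As a self-contained tabular sketch (again on a made-up two-state MDP, not the chapter's code), the loop below alternates fitted evaluation with greedy improvement, starting from the constant zero policy as `fitted_policy_iteration` does by default. Policies are represented as integer arrays `pi[h, s]`, an assumption made for illustration. In this tiny deterministic example, two rounds are enough for the greedy policy to stop changing; that is a property of the toy, not a general guarantee.

```python
import jax.numpy as np

# Hypothetical toy MDP: action 0 stays put, action 1 switches state; reward 1 in state 1.
S, A, H = 2, 2, 3
rows = [(h, s, a, float(s == 1), s if a == 0 else 1 - s)
        for h in range(H) for s in range(S) for a in range(A)]
h, s, a, r, s_next = (np.array(col) for col in zip(*rows))


def fit_tabular(y):
    """Mean label in each (h, s, a) cell: exact ERM over tabular functions."""
    sums = np.zeros((H, S, A)).at[h, s, a].add(y)
    counts = np.zeros((H, S, A)).at[h, s, a].add(1.0)
    return sums / np.maximum(counts, 1.0)


def fitted_evaluation_tabular(pi, epochs):
    Q = np.zeros((H, S, A))
    for _ in range(epochs):
        h_next = np.minimum(h + 1, H - 1)
        V_next = Q[h_next, s_next, pi[h_next, s_next]]  # bootstrap with pi's action
        Q = fit_tabular(r + np.where(h == H - 1, 0.0, V_next))
    return Q


def fitted_policy_iteration_tabular(epochs, evaluation_epochs):
    pi = np.zeros((H, S), dtype=np.int32)  # constant zero policy
    for _ in range(epochs):
        Q = fitted_evaluation_tabular(pi, evaluation_epochs)
        pi = Q.argmax(axis=-1)  # greedy w.r.t. the fitted Q-function
    return pi, Q


pi, Q = fitted_policy_iteration_tabular(epochs=2, evaluation_epochs=H)
print(pi)  # greedy action for each (h, s)
```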

5.6 Summary