5  Fitted Dynamic Programming Algorithms



5.1 Introduction

We borrow these definitions from Chapter 1:

Code
from utils import gym, tqdm, rand, Float, Array, NamedTuple, Callable, Optional, np

key = rand.PRNGKey(184)

class Transition(NamedTuple):
    s: int
    a: int
    r: float


Trajectory = list[Transition]


def get_num_actions(trajectories: list[Trajectory]) -> int:
    """Get the number of actions in the dataset. Assumes actions range from 0 to A-1."""
    return max(max(t.a for t in τ) for τ in trajectories) + 1


State = Float[Array, "..."]  # arbitrary shape

# assume finite `A` actions and f outputs an array of Q-values
# i.e. Q(s, a, h) is implemented as f(s, h)[a]
QFunction = Callable[[State, int], Float[Array, " A"]]


def Q_zero(A: int) -> QFunction:
    """A Q-function that always returns zero."""
    return lambda s, h: np.zeros(A)


# a deterministic time-dependent policy
Policy = Callable[[State, int], int]


def q_to_greedy(Q: QFunction) -> Policy:
    """Get the greedy policy for the given state-action value function."""
    return lambda s, h: np.argmax(Q(s, h))

Chapter 1 discussed the case of finite MDPs, where the state and action spaces \(\mathcal{S}\) and \(\mathcal{A}\) were finite. This gave us a closed-form expression for computing the r.h.s. of Theorem 1.1. In this chapter, we consider the case of large or continuous state spaces, where the state space is too large to be enumerated. In this case, we need to approximate the value function and Q-function using methods from supervised learning.

5.2 Fitted value iteration

Let us apply empirical risk minimization (ERM) to the RL problem of computing the optimal policy / value function.

How did we compute the optimal value function in MDPs with finite state and action spaces?

Our existing approaches represent the value function, and the MDP itself, in matrix notation. But what happens if the state space is extremely large, or even infinite (e.g. real-valued)? Then computing a weighted sum over all possible next states, which is required to compute the Bellman operator, becomes intractable.
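
Concretely, with a finite state space, the expectation over the next state in the Bellman optimality equation is an explicit sum; schematically,

\[ Q_h^\star(s, a) = r(s, a) + \sum_{s' \in \mathcal{S}} P(s' \mid s, a) \max_{a'} Q_{h+1}^\star(s', a'), \]

which involves \(|\mathcal{S}|\) terms for every state-action pair and cannot be evaluated when \(\mathcal{S}\) is huge or infinite.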

Instead, we will need to use function approximation methods from supervised learning to solve for the value function in an alternative way.

In particular, suppose we have a dataset of \(N\) trajectories \(\tau_1, \dots, \tau_N \sim \rho_{\pi}\) from some policy \(\pi\) (called the data collection policy) acting in the MDP of interest. Let us indicate the trajectory index in the superscript, so that

\[ \tau_i = \{ s_0^i, a_0^i, r_0^i, s_1^i, a_1^i, r_1^i, \dots, s_{H-1}^i, a_{H-1}^i, r_{H-1}^i \}. \]

Code
def collect_data(
    env: gym.Env, N: int, H: int, key: rand.PRNGKey, π: Optional[Policy] = None
) -> list[Trajectory]:
    """Collect a dataset of trajectories from the given policy (or a random one)."""
    trajectories = []
    seeds = [rand.bits(k).item() for k in rand.split(key, N)]
    for i in tqdm(range(N)):
        τ = []
        s, _ = env.reset(seed=seeds[i])
        for h in range(H):
            # follow the given policy, or sample a random action if none is given
            a = π(s, h) if π else env.action_space.sample()
            s_next, r, terminated, truncated, _ = env.step(a)
            τ.append(Transition(s, a, r))
            if terminated or truncated:
                break
            s = s_next
        trajectories.append(τ)
    return trajectories
Code
env = gym.make("LunarLander-v3")
trajectories = collect_data(env, 100, 300, key)
trajectories[0][:5]  # show first five transitions from first trajectory
[Transition(s=array([-0.00767412,  1.4020356 , -0.77731264, -0.3948966 ,  0.00889908,
         0.17607284,  0.        ,  0.        ], dtype=float32), a=np.int64(3), r=np.float64(0.01510256111183253)),
 Transition(s=array([-0.01526899,  1.392572  , -0.76625407, -0.42065707,  0.01559265,
         0.133885  ,  0.        ,  0.        ], dtype=float32), a=np.int64(0), r=np.float64(-0.9906062816993995)),
 Transition(s=array([-0.02286405,  1.3825084 , -0.76627475, -0.4473554 ,  0.02228237,
         0.13380662,  0.        ,  0.        ], dtype=float32), a=np.int64(0), r=np.float64(-0.9934907824199968)),
 Transition(s=array([-0.0304594 ,  1.3718452 , -0.7662946 , -0.47403094,  0.02897083,
         0.133782  ,  0.        ,  0.        ], dtype=float32), a=np.int64(2), r=np.float64(1.4450094337078212)),
 Transition(s=array([-0.03802614,  1.361714  , -0.76368487, -0.45042533,  0.03589971,
         0.13859043,  0.        ,  0.        ], dtype=float32), a=np.int64(0), r=np.float64(-1.0299876604044584))]

Can we view the dataset of trajectories as a “labelled dataset” in order to apply supervised learning to approximate the optimal Q-function? Yes! Recall that we can characterize the optimal Q-function using Theorem 1.3, which does not depend on a specific policy:

\[ Q_h^\star(s, a) = r(s, a) + \mathbb{E}_{s' \sim P(s, a)} [\max_{a'} Q_{h+1}^\star(s', a')] \]

We can think of the arguments to the Q-function – i.e. the current state, action, and timestep \(h\) – as the inputs \(x\), and the r.h.s. of the above equation as the label \(f(x)\). Note that the r.h.s. can also be expressed as a conditional expectation:

\[ f(x) = \mathbb{E}[y \mid x] \quad \text{where} \quad y = r(s_h, a_h) + \max_{a'} Q^\star_{h+1}(s', a'). \]

Approximating the conditional expectation is precisely the task that Section 4.3 is suited for!

Our dataset of trajectories above would give us \(N \cdot H\) samples:

\[ x_{i h} = (s_h^i, a_h^i, h) \qquad y_{i h} = r(s_h^i, a_h^i) + \max_{a'} Q^\star_{h+1}(s_{h+1}^i, a') \]

Code
def get_X(trajectories: list[Trajectory]):
    """
    We pass the state and timestep as input to the Q-function
    and return an array of Q-values.
    """
    rows = [(τ[h].s, τ[h].a, h) for τ in trajectories for h in range(len(τ))]
    return [np.stack(ary) for ary in zip(*rows)]


def get_y(
    trajectories: list[Trajectory],
    f: Optional[QFunction] = None,
    π: Optional[Policy] = None,
):
    """
    Transform the dataset of trajectories into a dataset for supervised learning.
    If `π` is None, the labels target the optimal Q-function;
    otherwise, they target the Q-function of `π`.
    """
    f = f or Q_zero(get_num_actions(trajectories))
    y = []
    for τ in trajectories:
        for h in range(len(τ) - 1):
            s, a, r = τ[h]
            # evaluate the current estimate at the *next* state s_{h+1}
            s_next = τ[h + 1].s
            Q_values = f(s_next, h + 1)
            y.append(r + (Q_values[π(s_next, h + 1)] if π else Q_values.max()))
        # the final transition has no successor, so its label is just the reward
        y.append(τ[-1].r)
    return np.array(y)
Code
s, a, h = get_X(trajectories[:1])
print("states:", s[:5])
print("actions:", a[:5])
print("timesteps:", h[:5])
states: [[-0.00767412  1.4020356  -0.77731264 -0.3948966   0.00889908  0.17607284
   0.          0.        ]
 [-0.01526899  1.392572   -0.76625407 -0.42065707  0.01559265  0.133885
   0.          0.        ]
 [-0.02286405  1.3825084  -0.76627475 -0.4473554   0.02228237  0.13380662
   0.          0.        ]
 [-0.0304594   1.3718452  -0.7662946  -0.47403094  0.02897083  0.133782
   0.          0.        ]
 [-0.03802614  1.361714   -0.76368487 -0.45042533  0.03589971  0.13859043
   0.          0.        ]]
actions: [3 0 0 2 0]
timesteps: [0 1 2 3 4]
Code
get_y(trajectories[:1])[:5]
array([ 0.01510256, -0.99060628, -0.99349078,  1.44500943, -1.02998766])

Then we can use empirical risk minimization to find a function \(\hat f\) that approximates the optimal Q-function.

Code
# We will see some examples of fitting methods in the next section
FittingMethod = Callable[[Float[Array, "N D"], Float[Array, " N"]], QFunction]
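
As a concrete placeholder (not one of the fitting methods covered later), here is a minimal sketch of a function satisfying this interface: ordinary least squares on hand-built features of the state, action, and timestep. It assumes `X` is the `(states, actions, timesteps)` list returned by `get_X`, and the name `fit_least_squares` is ours rather than part of the provided utilities.

Code
def fit_least_squares(X, y) -> QFunction:
    """
    Illustrative sketch only: fit Q by ordinary least squares on
    (state, one-hot action, timestep) features.
    """
    states, actions, timesteps = X
    A = int(actions.max()) + 1
    # design matrix: state features, one-hot action, timestep, and a bias term
    Phi = np.concatenate(
        [states, np.eye(A)[actions], timesteps[:, None], np.ones((len(y), 1))],
        axis=1,
    )
    # solve the least-squares problem min_w ||Phi w - y||^2
    w, *_ = np.linalg.lstsq(Phi, y, rcond=None)

    def Q_hat(s: State, h: int) -> Float[Array, " A"]:
        # evaluate the fitted linear model at (s, a, h) for every action a
        feats = np.stack(
            [np.concatenate([s, np.eye(A)[a], np.array([h, 1.0])]) for a in range(A)]
        )
        return feats @ w

    return Q_hat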

But notice that the definition of \(y_{i h}\) depends on the Q-function itself! How can we resolve this circular dependency? Recall that we faced the same issue when evaluating a policy in an infinite-horizon MDP (Section 1.3.6.1). There, we iterated Definition 1.7, since we knew that the policy’s value function was a fixed point of the policy’s Bellman operator. We can apply the same strategy here, using the \(\hat f\) from the previous iteration to compute the labels \(y_{i h}\), and then using this new dataset to fit the next iterate.

Definition 5.1 (Fitted Q-function iteration)  

  1. Initialize some function \(\hat f(s, a, h) \in \mathbb{R}\).
  2. Iterate the following:
    1. Generate a supervised learning dataset \(X, y\) from the trajectories and the current estimate \(\hat f\), where the labels come from the r.h.s. of the Bellman optimality operator (Equation 1.19).
    2. Set \(\hat f\) to the function that minimizes the empirical risk:

\[ \hat f \gets \arg\min_f \frac{1}{N} \sum_{i=1}^N (y_i - f(x_i))^2. \]

Code
def fitted_q_iteration(
    trajectories: list[Trajectory],
    fit: FittingMethod,
    epochs: int,
    Q_init: Optional[QFunction] = None,
) -> QFunction:
    """
    Run fitted Q-function iteration using the given dataset.
    Returns an estimate of the optimal Q-function.
    """
    Q_hat = Q_init or Q_zero(get_num_actions(trajectories))
    X = get_X(trajectories)
    for _ in range(epochs):
        y = get_y(trajectories, Q_hat)
        Q_hat = fit(X, y)
    return Q_hat
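
For instance (a hypothetical usage, relying on the `fit_least_squares` sketch from earlier), we could run a few rounds of fitted Q-function iteration on the LunarLander dataset above and read off the greedy policy:

Code
# hypothetical usage, assuming the fit_least_squares sketch defined earlier
Q_hat = fitted_q_iteration(trajectories, fit_least_squares, epochs=10)
π_hat = q_to_greedy(Q_hat)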

5.3 Fitted policy evaluation

We can also use this fixed-point iteration to evaluate a policy using the dataset (the policy being evaluated need not be the one that generated the trajectories):

Definition 5.2 (Fitted policy evaluation) Input: Policy \(\pi : \mathcal{S} \times [H] \to \Delta(\mathcal{A})\) to be evaluated.

Output: An approximation of the Q-function \(Q^\pi\) of the policy.

  1. Initialize some function \(\hat f(s, a, h) \in \mathbb{R}\).
  2. Iterate the following:
    1. Generate a supervised learning dataset \(X, y\) from the trajectories and the current estimate \(\hat f\), where the labels come from the r.h.s. of Theorem 1.1 for the given policy.
    2. Set \(\hat f\) to the function that minimizes the empirical risk:

\[ \hat f \gets \arg\min_f \frac{1}{N} \sum_{i=1}^N (y_i - f(x_i))^2. \]

Code
def fitted_evaluation(
    trajectories: list[Trajectory],
    fit: FittingMethod,
    π: Policy,
    epochs: int,
    Q_init: Optional[QFunction] = None,
) -> QFunction:
    """
    Run fitted policy evaluation using the given dataset.
    Returns an estimate of the Q-function of the given policy.
    """
    Q_hat = Q_init or Q_zero(get_num_actions(trajectories))
    X = get_X(trajectories)
    for _ in tqdm(range(epochs)):
        y = get_y(trajectories, Q_hat, π)
        Q_hat = fit(X, y)
    return Q_hat

Spot the difference between fitted_evaluation and fitted_q_iteration. (See the definition of get_y.) How would you modify this algorithm to evaluate the data collection policy?

5.4 Fitted policy iteration

We can use this policy evaluation algorithm to adapt policy iteration (Section 1.3.7.2) to this new setting. The algorithm remains exactly the same – repeatedly make the policy greedy w.r.t. its own value function – except now we must evaluate the policy (i.e. compute its value function) using the iterative fitted-evaluation algorithm.

Code
def fitted_policy_iteration(
    trajectories: list[Trajectory],
    fit: FittingMethod,
    epochs: int,
    evaluation_epochs: int,
    π_init: Optional[Policy] = lambda s, h: 0,  # policy that always takes action 0
):
    """Run fitted policy iteration using the given dataset."""
    π = π_init
    for _ in range(epochs):
        Q_hat = fitted_evaluation(trajectories, fit, π, evaluation_epochs)
        π = q_to_greedy(Q_hat)
    return π
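
Similarly (again relying on the hypothetical `fit_least_squares` sketch), fitted policy iteration can be run on the same dataset:

Code
# hypothetical usage, assuming the fit_least_squares sketch defined earlier
π_fpi = fitted_policy_iteration(
    trajectories, fit_least_squares, epochs=5, evaluation_epochs=5
)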

5.5 Summary