We borrow these definitions from Chapter 1:
```python
from utils import gym, tqdm, rand, Float, Array, NamedTuple, Callable, Optional, np

key = rand.PRNGKey(184)


class Transition(NamedTuple):
    s: int
    a: int
    r: float


Trajectory = list[Transition]


def get_num_actions(trajectories: list[Trajectory]) -> int:
    """Get the number of actions in the dataset. Assumes actions range from 0 to A-1."""
    return max(max(t.a for t in τ) for τ in trajectories) + 1


State = Float[Array, "..."]  # arbitrary shape

# assume finite `A` actions and f outputs an array of Q-values
# i.e. Q(s, a, h) is implemented as f(s, h)[a]
QFunction = Callable[[State, int], Float[Array, " A"]]


def Q_zero(A: int) -> QFunction:
    """A Q-function that always returns zero."""
    return lambda s, h: np.zeros(A)


# a deterministic time-dependent policy
Policy = Callable[[State, int], int]


def q_to_greedy(Q: QFunction) -> Policy:
    """Get the greedy policy for the given state-action value function."""
    return lambda s, h: np.argmax(Q(s, h))
```
Chapter 1 discussed the case of finite MDPs, where the state and action spaces \(\mathcal{S}\) and \(\mathcal{A}\) were finite. This gave us a closed-form expression for computing the r.h.s. of Theorem 1.1. In this chapter, we consider the case of large or continuous state spaces, where the state space is too large to be enumerated. In this case, we need to approximate the value function and Q-function using methods from supervised learning.
5.2 Fitted value iteration
Let us apply ERM to the RL problem of computing the optimal policy / value function.
How did we compute the optimal value function in MDPs with finite state and action spaces?
In a finite-horizon MDP (Section 1.2), we can use Definition 1.11, working backwards from the end of the time horizon, to compute the optimal value function exactly.
In an infinite-horizon MDP (Section 1.3), we can use the algorithm of Section 1.3.7.1, which iterates the Bellman optimality operator (Equation 1.19) to approximately compute the optimal value function.
Our existing approaches represent the value function, and the MDP itself, in matrix notation. But what happens if the state space is extremely large, or even infinite (e.g. real-valued)? Then computing a weighted sum over all possible next states, which is required to compute the Bellman operator, becomes intractable.
Instead, we will need to use function approximation methods from supervised learning to solve for the value function in an alternative way.
In particular, suppose we have a dataset of \(N\) trajectories \(\tau_1, \dots, \tau_N \sim \rho_{\pi}\) from some policy \(\pi\) (called the data collection policy) acting in the MDP of interest. Let us indicate the trajectory index in the superscript, so that

\[
\tau^i = (s^i_0, a^i_0, r^i_0, \dots, s^i_{H-1}, a^i_{H-1}, r^i_{H-1}).
\]
```python
def collect_data(
    env: gym.Env, N: int, H: int, key: rand.PRNGKey, π: Optional[Policy] = None
) -> list[Trajectory]:
    """Collect a dataset of trajectories from the given policy (or a random one)."""
    trajectories = []
    seeds = [rand.bits(k).item() for k in rand.split(key, N)]
    for i in tqdm(range(N)):
        τ = []
        s, _ = env.reset(seed=seeds[i])
        for h in range(H):
            # sample from the given policy, or uniformly at random if none is given
            a = π(s, h) if π else env.action_space.sample()
            s_next, r, terminated, truncated, _ = env.step(a)
            τ.append(Transition(s, a, r))
            if terminated or truncated:
                break
            s = s_next
        trajectories.append(τ)
    return trajectories
```
```python
env = gym.make("LunarLander-v3")
trajectories = collect_data(env, 100, 300, key)
trajectories[0][:5]  # show first five transitions from first trajectory
```
Can we view the dataset of trajectories as a “labelled dataset” in order to apply supervised learning to approximate the optimal Q-function? Yes! Recall that we can characterize the optimal Q-function using Theorem 1.3, which doesn’t depend on an actual policy:
\[
Q_h^\star(s, a) = r(s, a) + \mathbb{E}_{s' \sim P(s, a)} [\max_{a'} Q_{h+1}^\star(s', a')]
\]
We can think of the arguments to the Q-function – i.e. the current state, action, and timestep \(h\) – as the inputs \(x\), and the r.h.s. of the above equation as the label \(f(x)\). Note that the r.h.s. can also be expressed as a conditional expectation:

\[
f(x) = \mathbb{E} \left[ r_h + \max_{a'} Q^\star_{h + 1}(s_{h + 1}, a') \mid s_h = s, a_h = a \right].
\]
```python
def get_X(trajectories: list[Trajectory]):
    """
    Extract the inputs to the Q-function -- the (state, action, timestep)
    triples -- from the trajectories as parallel arrays.
    """
    rows = [(τ[h].s, τ[h].a, h) for τ in trajectories for h in range(len(τ))]
    return [np.stack(ary) for ary in zip(*rows)]


def get_y(
    trajectories: list[Trajectory],
    f: Optional[QFunction] = None,
    π: Optional[Policy] = None,
):
    """
    Transform the dataset of trajectories into a dataset for supervised learning.
    If `π` is None, estimates the optimal Q-function.
    Otherwise, estimates the Q-function of π.
    """
    f = f or Q_zero(get_num_actions(trajectories))
    y = []
    for τ in trajectories:
        for h in range(len(τ) - 1):
            s, a, r = τ[h]
            # the label bootstraps from the *next* state at timestep h + 1
            s_next = τ[h + 1].s
            Q_values = f(s_next, h + 1)
            y.append(r + (Q_values[π(s_next, h + 1)] if π else Q_values.max()))
        # the final transition has no successor, so its label is just the reward
        y.append(τ[-1].r)
    return np.array(y)
```
```python
s, a, h = get_X(trajectories[:1])
print("states:", s[:5])
print("actions:", a[:5])
print("timesteps:", h[:5])
```
Then we can use empirical risk minimization to find a function \(\hat f\) that approximates the optimal Q-function.
```python
# We will see some examples of fitting methods in the next section
FittingMethod = Callable[[Float[Array, "N D"], Float[Array, " N"]], QFunction]
```
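As a concrete illustration (this example is not from the text), here is a minimal sketch of a fitting method that fits one ridge-regularized linear model per action. It assumes `X` is the `(states, actions, timesteps)` triple produced by `get_X`, with flat float state vectors; the name `fit_linear`, the default of `A = 4` actions (matching LunarLander's action space), and the feature construction are all illustrative choices.

```python
import numpy as np

def fit_linear(X, y, A: int = 4, lam: float = 1e-3):
    """Hypothetical FittingMethod: ridge regression per action.

    X is assumed to be the (states, actions, timesteps) triple;
    y holds the regression targets. Returns a QFunction-style closure.
    """
    states, actions, timesteps = X
    # feature vector: the state with the timestep appended
    feats = np.column_stack([states, timesteps]).astype(float)
    D = feats.shape[1]
    W = np.zeros((A, D))
    for a in range(A):
        mask = actions == a
        if mask.any():
            Phi = feats[mask]
            # ridge solution: (Φ'Φ + λI)^{-1} Φ'y
            W[a] = np.linalg.solve(Phi.T @ Phi + lam * np.eye(D), Phi.T @ y[mask])

    def Q(s, h):
        x = np.append(np.asarray(s, dtype=float), float(h))
        return W @ x  # array of A estimated Q-values

    return Q
```

A linear model is of course a crude function class for a task like LunarLander; the point is only the interface, namely that `fit` consumes a labelled dataset and returns a new Q-function.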
But notice that the definition of \(y_{i h}\) depends on the Q-function itself! How can we resolve this circular dependency? Recall that we faced the same issue when evaluating a policy in an infinite-horizon MDP (Section 1.3.6.1). There, we iterated Definition 1.7, since we knew that the policy’s value function was a fixed point of the policy’s Bellman operator. We can apply the same strategy here: use the \(\hat f\) from the previous iteration to compute the labels \(y_{i h}\), and then use this new dataset to fit the next iterate.
Definition 5.1 (Fitted Q-function iteration)
Initialize some function \(\hat f(s, a, h) \in \mathbb{R}\).
Iterate the following:
Generate a supervised learning dataset \(X, y\) from the trajectories and the current estimate \(\hat f\), where the labels come from the r.h.s. of the Bellman optimality operator (Equation 1.19).
Set \(\hat f\) to the function that minimizes the empirical risk:

\[
\hat f \gets \arg\min_f \frac{1}{N} \sum_{i=1}^N \sum_{h=0}^{H-1} \left( y_{i h} - f(s_{i h}, a_{i h}, h) \right)^2.
\]
```python
def fitted_q_iteration(
    trajectories: list[Trajectory],
    fit: FittingMethod,
    epochs: int,
    Q_init: Optional[QFunction] = None,
) -> QFunction:
    """
    Run fitted Q-function iteration using the given dataset.
    Returns an estimate of the optimal Q-function.
    """
    Q_hat = Q_init or Q_zero(get_num_actions(trajectories))
    X = get_X(trajectories)
    for _ in range(epochs):
        y = get_y(trajectories, Q_hat)
        Q_hat = fit(X, y)
    return Q_hat
```
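To see the fixed-point iteration at work without any approximation error, consider the following toy check (an assumed setup, not from the text): on a tiny deterministic MDP with full data coverage, where the "fit" step is exact tabular copying, repeating the relabel-and-refit loop recovers the optimal Q-function.

```python
import numpy as np

# Toy MDP (illustrative): 2 states, 2 actions, horizon 3, deterministic
# dynamics. reward[s, a] is the reward; next_state[s, a] the successor.
S, A, H = 2, 2, 3
reward = np.array([[0.0, 1.0], [1.0, 0.0]])
next_state = np.array([[0, 1], [1, 0]])

Q_hat = np.zeros((H, S, A))  # initialize to the zero Q-function
for _ in range(H):  # each sweep: relabel with r + max_a' Q_{h+1}, then "fit"
    y = np.empty_like(Q_hat)
    for h in range(H):
        for s in range(S):
            for a in range(A):
                if h == H - 1:
                    y[h, s, a] = reward[s, a]  # no successor at the last step
                else:
                    y[h, s, a] = reward[s, a] + Q_hat[h + 1, next_state[s, a]].max()
    Q_hat = y  # a tabular "fit" reproduces the labels exactly
```

After \(H\) sweeps, `Q_hat[0, 0, 1]` equals 3: from state 0, taking action 1 earns reward 1 and reaches state 1, where action 0 earns reward 1 at each of the two remaining steps. With a genuine function approximator, each sweep would instead incur some fitting error.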
5.3 Fitted policy evaluation
We can also use this fixed-point iteration to evaluate a policy using the dataset (the policy being evaluated need not be the one that generated the trajectories):
Definition 5.2 (Fitted policy evaluation)
Input: Policy \(\pi : \mathcal{S} \times [H] \to \Delta(\mathcal{A})\) to be evaluated.
Output: An approximation of the value function \(Q^\pi\) of the policy.
Initialize some function \(\hat f(s, a, h) \in \mathbb{R}\).
Iterate the following:
Generate a supervised learning dataset \(X, y\) from the trajectories and the current estimate \(\hat f\), where the labels come from the r.h.s. of Theorem 1.1 for the given policy.
Set \(\hat f\) to the function that minimizes the empirical risk:

\[
\hat f \gets \arg\min_f \frac{1}{N} \sum_{i=1}^N \sum_{h=0}^{H-1} \left( y_{i h} - f(s_{i h}, a_{i h}, h) \right)^2.
\]
```python
def fitted_evaluation(
    trajectories: list[Trajectory],
    fit: FittingMethod,
    π: Policy,
    epochs: int,
    Q_init: Optional[QFunction] = None,
) -> QFunction:
    """
    Run fitted policy evaluation using the given dataset.
    Returns an estimate of the Q-function of the given policy.
    """
    Q_hat = Q_init or Q_zero(get_num_actions(trajectories))
    X = get_X(trajectories)
    for _ in tqdm(range(epochs)):
        y = get_y(trajectories, Q_hat, π)
        Q_hat = fit(X, y)
    return Q_hat
```
Spot the difference between fitted_evaluation and fitted_q_iteration. (See the definition of get_y.) How would you modify this algorithm to evaluate the data collection policy?
5.4 Fitted policy iteration
We can use this policy evaluation algorithm to adapt the algorithm of Section 1.3.7.2 to this new setting. The algorithm remains exactly the same, repeatedly making the policy greedy w.r.t. its own value function, except that we must now evaluate the policy (i.e. compute its value function) using the iterative fitted-evaluation algorithm.
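The resulting loop can be sketched in a self-contained way (an assumed toy setup, not from the text) by pairing exact backwards-induction evaluation, which is the fixed point that fitted evaluation approximates, with greedy improvement on a tiny deterministic MDP:

```python
import numpy as np

# Toy MDP (illustrative): 2 states, 2 actions, horizon 3, deterministic
# dynamics. Policies are stored as an (H, S) table of actions.
S, A, H = 2, 2, 3
reward = np.array([[0.0, 1.0], [1.0, 0.0]])
next_state = np.array([[0, 1], [1, 0]])

def evaluate(π):
    """Compute Q^π exactly by backwards induction (the fixed point that
    fitted policy evaluation approximates)."""
    Q = np.zeros((H, S, A))
    Q[H - 1] = reward  # no future value at the final step
    for h in reversed(range(H - 1)):
        for s in range(S):
            for a in range(A):
                s2 = next_state[s, a]
                Q[h, s, a] = reward[s, a] + Q[h + 1, s2, π[h + 1, s2]]
    return Q

# policy iteration: evaluate, then act greedily w.r.t. the Q-function
π = np.zeros((H, S), dtype=int)  # initial policy: always take action 0
for _ in range(5):
    Q = evaluate(π)
    π = Q.argmax(axis=2)
```

In the fitted setting, the `evaluate` step would be replaced by `fitted_evaluation` and the `argmax` by `q_to_greedy`, but the alternating structure is identical.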