5.1 Introduction
We borrow these definitions from the Markov Decision Processes chapter (Chapter 1):
from typing import NamedTuple, Callable, Optional
from jaxtyping import Float, Array
import jax.numpy as np
from jax import grad, vmap
import jax.random as rand
from tqdm import tqdm
import gymnasium as gym
key = rand.PRNGKey(184)
class Transition(NamedTuple):
s: int
a: int
r: float
Trajectory = list[Transition]
def get_num_actions(trajectories: list[Trajectory]) -> int:
"""Get the number of actions in the dataset. Assumes actions range from 0 to A-1."""
return max(max(t.a for t in τ) for τ in trajectories) + 1
State = Float[Array, "..."] # arbitrary shape
# assume finite `A` actions and f outputs an array of Q-values
# i.e. Q(s, a, h) is implemented as f(s, h)[a]
QFunction = Callable[[State, int], Float[Array, " A"]]
def Q_zero(A: int) -> QFunction:
"""A Q-function that always returns zero."""
    return lambda s, h: np.zeros(A)
# a deterministic time-dependent policy
Policy = Callable[[State, int], int]
def q_to_greedy(Q: QFunction) -> Policy:
"""Get the greedy policy for the given state-action value function."""
return lambda s, h: np.argmax(Q(s, h))
The Markov Decision Processes chapter (Chapter 1) discussed the case of finite MDPs, where the state space $\mathcal{S}$ and action space $\mathcal{A}$ were finite. This gave us a closed-form expression for computing the r.h.s. of the Bellman one-step consistency equation. In this chapter, we consider the case of large or continuous state spaces, where the state space is too large to be enumerated. In this case, we need to approximate the value function and Q-function using methods from supervised learning.
We will first take a quick detour to introduce the empirical risk minimization framework for function approximation. We will then see its application to fitted RL algorithms, which attempt to learn the optimal value function (and the optimal policy) from a dataset of trajectories.
5.2 Empirical risk minimization
The supervised learning task is as follows: We seek to learn the relationship between some input variables $x$ and some output variable $y$ (drawn from their joint distribution). Precisely, we want to find a function $\hat f$ that minimizes the squared error of the prediction:

$$
\hat f = \arg\min_{f} \mathbb{E}\left[ (y - f(x))^2 \right]
$$
An equivalent framing is that we seek to approximate the conditional expectation of $y$ given $x$:

$$
\hat f(x) = \mathbb{E}[y \mid x]
$$
In most applications, the joint distribution of $x$ and $y$ is unknown or extremely complex, and so we can’t analytically evaluate $\mathbb{E}[(y - f(x))^2]$. Instead, our strategy is to draw $N$ samples $(x_i, y_i)$ from the joint distribution of $x$ and $y$, and then use the sample average $\frac{1}{N} \sum_{i=1}^N (y_i - f(x_i))^2$ to approximate the mean squared error. Then we use a fitting method to find a function $\hat f$ that minimizes this objective and thus approximates the conditional expectation. This approach is called empirical risk minimization.
Given a dataset of samples $(x_1, y_1), \dots, (x_N, y_N)$, empirical risk minimization seeks to find a function $\hat f$ (from some class of functions $\mathcal{F}$) that minimizes the empirical risk:

$$
\hat f = \arg\min_{f \in \mathcal{F}} \frac{1}{N} \sum_{i=1}^N (y_i - f(x_i))^2
$$
We will cover the details of the minimization process in the next section.
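In the meantime, here is a minimal sketch of one possible fitting procedure: gradient descent on the empirical risk over the class of linear functions, using jax.grad. The helper fit_linear, the synthetic data, and the hyperparameters below are illustrative placeholders rather than part of the chapter’s codebase.
def fit_linear(X, y, learning_rate=0.1, steps=200):
    """Minimize the empirical risk over linear functions f(x) = x @ w + b."""
    w, b = np.zeros(X.shape[1]), 0.0

    def empirical_risk(w, b):
        predictions = X @ w + b
        return np.mean((y - predictions) ** 2)

    for _ in range(steps):
        # one gradient descent step on the empirical risk
        dw, db = grad(empirical_risk, argnums=(0, 1))(w, b)
        w, b = w - learning_rate * dw, b - learning_rate * db

    return lambda x: x @ w + b

# example usage on a synthetic dataset
X_demo = rand.normal(key, (100, 3))
y_demo = X_demo @ np.array([1.0, -2.0, 0.5]) + 0.1 * rand.normal(key, (100,))
f_hat = fit_linear(X_demo, y_demo)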
Why is it important that we constrain our search to a class of functions $\mathcal{F}$?
Hint: Consider a function that memorizes the dataset, e.g. $f(x) = \sum_{i=1}^{N} y_i \mathbf{1}\{x = x_i\}$. What is the empirical risk of this function? Would you consider it a good approximation of the conditional expectation?
5.3 Fitted value iteration
Let us apply ERM to the RL problem of computing the optimal policy / value function.
How did we compute the optimal value function in MDPs with finite state and action spaces?
In a finite-horizon MDP, we can use dynamic programming, working backwards from the end of the time horizon, to compute the optimal value function exactly.
In an infinite-horizon MDP, we can use value iteration, which iterates the Bellman optimality operator (1.54) to approximately compute the optimal value function.
Our existing approaches represent the value function, and the MDP itself, in matrix notation. But what happens if the state space is extremely large, or even infinite (e.g. real-valued)? Then computing a weighted sum over all possible next states, which is required to compute the Bellman operator, becomes intractable.
Instead, we will need to use function approximation methods from supervised learning to solve for the value function in an alternative way.
In particular, suppose we have a dataset of $N$ trajectories from some policy π (called the data collection policy) acting in the MDP of interest. Let us indicate the trajectory index in the superscript, so that

$$
\tau^i = (s_0^i, a_0^i, r_0^i, \; s_1^i, a_1^i, r_1^i, \; \dots, \; s_{H-1}^i, a_{H-1}^i, r_{H-1}^i)
$$

for $i = 1, \dots, N$.
def collect_data(
env: gym.Env, N: int, H: int, key: rand.PRNGKey, π: Optional[Policy] = None
) -> list[Trajectory]:
"""Collect a dataset of trajectories from the given policy (or a random one)."""
trajectories = []
seeds = [rand.bits(k).item() for k in rand.split(key, N)]
for i in tqdm(range(N)):
τ = []
s, _ = env.reset(seed=seeds[i])
for h in range(H):
            # use the given policy if provided; otherwise sample a random action
            a = π(s, h) if π else env.action_space.sample()
s_next, r, terminated, truncated, _ = env.step(a)
τ.append(Transition(s, a, r))
if terminated or truncated:
break
s = s_next
trajectories.append(τ)
return trajectories
env = gym.make("LunarLander-v2")
trajectories = collect_data(env, 100, 300, key)
trajectories[0][:5] # show first five transitions from first trajectory
Can we view the dataset of trajectories as a “labelled dataset” in order to apply supervised learning to approximate the optimal Q-function? Yes! Recall that we can characterize the optimal Q-function using the Bellman optimality equations, which don’t depend on an actual policy:

$$
Q_h^\star(s, a) = r(s, a) + \mathbb{E}_{s' \sim P(\cdot \mid s, a)} \left[ \max_{a'} Q_{h+1}^\star(s', a') \right]
$$
We can think of the arguments to the Q-function -- i.e. the current state, action, and timestep -- as the inputs $x$, and the r.h.s. of the above equation as the label $y$. Note that the r.h.s. can also be expressed as a conditional expectation:

$$
y = \mathbb{E} \left[ r_h + \max_{a'} Q_{h+1}^\star(s_{h+1}, a') \;\middle|\; s_h = s, a_h = a \right]
$$
Approximating a conditional expectation is precisely the task that the empirical risk minimization framework of Section 5.2 is suited for!
Our dataset above would give us $N \cdot H$ samples:

$$
x_{i h} = (s_h^i, a_h^i, h), \qquad y_{i h} = r_h^i + \max_{a'} Q_{h+1}^\star(s_{h+1}^i, a')
$$
def get_X(trajectories: list[Trajectory]):
    """
    Extract the inputs of the supervised learning dataset:
    stacked arrays of states, actions, and timesteps.
    """
    rows = [(τ[h].s, τ[h].a, h) for τ in trajectories for h in range(len(τ))]
    return [np.stack(ary) for ary in zip(*rows)]
def get_y(
    trajectories: list[Trajectory],
    f: Optional[QFunction] = None,
    π: Optional[Policy] = None,
):
    """
    Transform the dataset of trajectories into the labels for supervised learning.
    If `π` is None, the labels estimate the optimal Q function.
    Otherwise, they estimate the Q function of π.
    """
    f = f or Q_zero(get_num_actions(trajectories))
    y = []
    for τ in trajectories:
        for h in range(len(τ) - 1):
            s, a, r = τ[h]
            # bootstrap from the *next* state, as in the Bellman equations
            s_next = τ[h + 1].s
            Q_values = f(s_next, h + 1)
            y.append(r + (Q_values[π(s_next, h + 1)] if π else Q_values.max()))
        # the final transition has no recorded successor, so its label is just the reward
        y.append(τ[-1].r)
    return np.array(y)
s, a, h = get_X(trajectories[:1])
print("states:", s[:5])
print("actions:", a[:5])
print("timesteps:", h[:5])
get_y(trajectories[:1])[:5]
Then we can use empirical risk minimization to find a function that approximates the optimal Q-function.
# We will see some examples of fitting methods in the next section
FittingMethod = Callable[[Float[Array, "N D"], Float[Array, " N"]], QFunction]
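Note that in fitted_q_iteration (and fitted_evaluation) below, the fitting method actually receives the list of arrays produced by get_X rather than a single N × D matrix. As a hedged illustration of what such a fitting method could look like (not the method used later in the book), here is a sketch that fits a separate least-squares linear model per action; the name fit_least_squares and the feature construction are assumptions for illustration.
def fit_least_squares(X, y) -> QFunction:
    """Illustrative fitting method: one linear model per action, fit by least squares.
    Expects X = (states, actions, timesteps) as returned by `get_X`."""
    states, actions, timesteps = X
    H_max = timesteps.max() + 1
    # features: the state concatenated with the normalized timestep
    phi = np.concatenate([states, timesteps[:, None] / H_max], axis=1)
    A = int(actions.max()) + 1
    weights = np.zeros((A, phi.shape[1]))
    for action in range(A):
        mask = actions == action
        if mask.sum() > 0:
            # least-squares fit on the samples where this action was taken
            w, *_ = np.linalg.lstsq(phi[mask], y[mask], rcond=None)
            weights = weights.at[action].set(w)

    def Q_hat(s: State, h: int) -> Float[Array, " A"]:
        x = np.concatenate([s, np.array([h / H_max])])
        return weights @ x

    return Q_hat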
But notice that the definition of the labels $y$ depends on the Q-function itself! How can we resolve this circular dependency? Recall that we faced the same issue when evaluating a policy in an infinite-horizon MDP. There, we iterated the Bellman operator (Definition 1.8), since we knew that the policy’s value function was a fixed point of the policy’s Bellman operator. We can apply the same strategy here, using the $\hat Q$ from the previous iteration to compute the labels $y$, and then using this new dataset to fit the next iterate.
def fitted_q_iteration(
trajectories: list[Trajectory],
fit: FittingMethod,
epochs: int,
Q_init: Optional[QFunction] = None,
) -> QFunction:
"""
Run fitted Q-function iteration using the given dataset.
Returns an estimate of the optimal Q-function.
"""
Q_hat = Q_init or Q_zero(get_num_actions(trajectories))
X = get_X(trajectories)
for _ in range(epochs):
y = get_y(trajectories, Q_hat)
Q_hat = fit(X, y)
return Q_hat
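For instance, assuming the illustrative fit_least_squares sketched above as the fitting method, we could estimate the optimal Q-function from the dataset and extract a greedy policy as follows:
# sketch: estimate the optimal Q-function, then act greedily with respect to it
Q_hat = fitted_q_iteration(trajectories, fit_least_squares, epochs=10)
π_hat = q_to_greedy(Q_hat)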
5.4 Fitted policy evaluation
We can also use this fixed-point iteration to evaluate a policy (not necessarily the data collection policy) using the dataset:
Input: Policy $\pi$ to be evaluated.
Output: An approximation of the value function $Q^\pi$ of the policy.
- Initialize some function $\hat Q_0$.
- Iterate the following:
  - Generate a supervised learning dataset $(x_i, y_i)$ from the trajectories and the current estimate $\hat Q_t$, where the labels $y_i$ come from the r.h.s. of the Bellman consistency equation for the given policy.
  - Set $\hat Q_{t+1}$ to the function that minimizes the empirical risk:

$$
\hat Q_{t+1} = \arg\min_{f} \frac{1}{N} \sum_{i=1}^{N} (y_i - f(x_i))^2
$$
def fitted_evaluation(
trajectories: list[Trajectory],
fit: FittingMethod,
π: Policy,
epochs: int,
Q_init: Optional[QFunction] = None,
) -> QFunction:
"""
Run fitted policy evaluation using the given dataset.
Returns an estimate of the Q-function of the given policy.
"""
Q_hat = Q_init or Q_zero(get_num_actions(trajectories))
X = get_X(trajectories)
for _ in tqdm(range(epochs)):
y = get_y(trajectories, Q_hat, π)
Q_hat = fit(X, y)
return Q_hat
Spot the difference between fitted_evaluation and fitted_q_iteration. (See the definition of get_y.)
How would you modify this algorithm to evaluate the data collection policy?
5.5 Fitted policy iteration
We can use this policy evaluation algorithm to adapt the policy iteration algorithm to this new setting. The algorithm remains exactly the same -- repeatedly make the policy greedy w.r.t. its own value function -- except now we must evaluate the policy (i.e. compute its value function) using the iterative fitted_evaluation algorithm.
def fitted_policy_iteration(
trajectories: list[Trajectory],
fit: FittingMethod,
epochs: int,
evaluation_epochs: int,
π_init: Optional[Policy] = lambda s, h: 0, # constant zero policy
) -> Policy:
"""Run fitted policy iteration using the given dataset."""
π = π_init
for _ in range(epochs):
Q_hat = fitted_evaluation(trajectories, fit, π, evaluation_epochs)
π = q_to_greedy(Q_hat)
return π
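Putting it together, a run of fitted policy iteration on the collected trajectories might look as follows; the fitting method and iteration counts below are illustrative placeholders:
# sketch: alternate fitted evaluation and greedy improvement on the dataset
π_hat = fitted_policy_iteration(
    trajectories, fit_least_squares, epochs=5, evaluation_epochs=5
)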