8  Tree Search Methods



8.1 Introduction

Have you ever lost a strategy game against a skilled opponent? It probably seemed like they were ahead of you at every turn. They might have been planning ahead and anticipating your actions, then formulating their strategy to counter yours. If this opponent was a computer, they might have been using one of the strategies that we are about to explore.

Code
%load_ext autoreload
%autoreload 2
Code
from utils import Int, Array, latex, jnp, NamedTuple
from enum import IntEnum

8.2 Deterministic, zero sum, fully observable two-player games

In this chapter, we will focus on games that are:

  • deterministic,
  • zero sum (one player wins and the other loses),
  • fully observable, that is, the state of the game is perfectly known by both players,
  • for two players that alternate turns,

We can represent such a game as a complete game tree. Each possible state is a node in the tree, and since we only consider deterministic games, we can represent actions as edges leading from the current state to the next. Each path through the tree, from root to leaf, represents a single game.

The first two layers of the complete game tree of tic-tac-toe.

If you could store the complete game tree on a computer, you would be able to win every potentially winnable game by searching all paths from your current state and taking a winning move. We will see an explicit algorithm for this in Section 8.3. However, as games become more complex, it becomes computationally impossible to search every possible path.

For instance, a chess player has roughly 30 actions to choose from at each turn, and each game takes roughly 40 moves per player, so trying to solve chess exactly using minimax would take somewhere on the order of \(30^{80} \approx 10^{118}\) operations. That’s 10 billion billion billion billion billion billion billion billion billion billion billion billion billion operations. As of the time of writing, the fastest processor can achieve almost 10 GHz (10 billion operations per second), so to fully solve chess using minimax is many, many orders of magnitude out of reach.
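As a quick sanity check of this back-of-the-envelope estimate, we can compute the exponent directly; the figures \(30\) and \(80\) are simply the rough values quoted above.

Code
import math

# Rough size of the chess game tree: about 30 legal moves per position,
# over roughly 80 total moves (40 per player).
print(f"log10(30**80) ≈ {80 * math.log10(30):.1f}")  # ≈ 118.2, i.e. 30^80 ≈ 10^118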

It is thus intractable, in any realistic setting, to solve the complete game tree exactly. Luckily, only a small fraction of those games ever occur in reality. Later in this chapter, we will explore ways to prune away parts of the tree that we know we can safely ignore. We can also approximate the value of a state without fully evaluating it. Using these approximations, we can no longer guarantee winning the game, but we can come up with strategies that will do well against most opponents.

8.2.1 Notation

Let us now describe these games formally. We’ll call the first player Max and the second player Min. Max seeks to maximize the final game score, while Min seeks to minimize the final game score.

  • We’ll use \(\mathcal{S}\) to denote the set of all possible game states.
  • The game begins in some initial state \(s_0 \in \mathcal{S}\).
  • Max moves on even turn numbers \(h = 2n\), and Min moves on odd turn numbers \(h = 2n+1\), where \(n\) is a natural number.
  • The space of possible actions, \(\mathcal{A}_h(s)\), depends on the state itself, as well as whose turn it is. (For example, in tic-tac-toe, Max can only play Xs while Min can only play Os.)
  • The game ends after \(H\) total moves (which might be even or odd). We call the final state a terminal state.
  • \(P\) denotes the state transitions, that is, \(P(s, a)\) denotes the resulting state when taking action \(a \in \mathcal{A}(s)\) in state \(s\). We’ll assume that this function is time-homogeneous (a.k.a. stationary), that is, it doesn’t change across timesteps.
  • \(r(s)\) denotes the game score of the terminal state \(s\). Note that this is some positive or negative value seen by both players: A positive value indicates Max winning, a negative value indicates Min winning, and a value of \(0\) indicates a tie.

We also call the sequence of states and actions a trajectory.

Above, we suppose that the game ends after \(H\) total moves. But most real games have a variable length. How would you describe this?
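Before turning to a concrete example, here is a minimal sketch of what this interface might look like in code. The protocol name `TwoPlayerGame` and its method names (`actions`, `step`, `is_terminal`, `reward`) are illustrative choices rather than part of the formal setup; Example 8.1 and the environment code below instantiate the same ideas for tic-tac-toe.

Code
from typing import Protocol, TypeVar

State = TypeVar("State")
Action = TypeVar("Action")


class TwoPlayerGame(Protocol[State, Action]):
    """A deterministic, zero-sum, fully observable two-player game."""

    init_state: State  # the initial state s_0

    def actions(self, s: State) -> list[Action]:
        """The legal actions in state s (in general, these may also depend on whose turn it is)."""
        ...

    def step(self, s: State, a: Action) -> State:
        """The deterministic transition P(s, a)."""
        ...

    def is_terminal(self, s: State) -> bool:
        """Whether s is a terminal state."""
        ...

    def reward(self, s: State) -> float:
        """The game score r(s) of a terminal state; positive means Max wins."""
        ...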

Example 8.1 (Tic-tac-toe) Let us frame tic-tac-toe in this setting.

  • Each of the \(9\) squares is either empty, marked X, or marked O. So there are \(|\mathcal{S}| = 3^9\) potential states. Not all of these may be reachable!
  • The initial state \(s_0\) is the empty board.
  • The set of possible actions for Max in state \(s\), \(\mathcal{A}_{2n}(s)\), is the set of tuples \((\text{``X''}, i)\) where \(i\) refers to an empty square in \(s\). Similarly, \(\mathcal{A}_{2n+1}(s)\) is the set of tuples \((\text{``O''}, i)\) where \(i\) refers to an empty square in \(s\).
  • We can take \(H = 9\) as the longest possible game length.
  • \(P(s, a)\) for a nonterminal state \(s\) is simply the board with the symbol and square specified by \(a\) marked into \(s\). Otherwise, if \(s\) is a terminal state, i.e. it already has three symbols in a row, the state no longer changes.
  • \(r(s)\) at a terminal state is \(+1\) if there are three Xs in a row, \(-1\) if there are three Os in a row, and \(0\) otherwise.

Our notation may remind you of Chapter 1. Given that these games also involve a sequence of states and actions, can we formulate them as finite-horizon MDPs? The two settings are not exactly analogous, since in MDPs we only consider a single policy, while these games involve two distinct players with opposite objectives. Since we want to analyze the behavior of both players at the same time, describing such a game as an MDP is more trouble than it’s worth.

Code
class Player(IntEnum):
    EMPTY = 0
    X = 1
    O = 2


if False:
    # The class below assumes the Gymnasium API for `gym.Env` and `spaces`
    # (the classic `gym` package exposes the same names).
    import gymnasium as gym
    from gymnasium import spaces

    class TicTacToeEnv(gym.Env):
        metadata = {"render.modes": ["human"]}

        def __init__(self):
            super().__init__()
            self.action_space = spaces.Discrete(9)
            self.observation_space = spaces.Box(
                low=0, high=2, shape=(3, 3), dtype=jnp.int32
            )
            self.board = None
            self.current_player = None
            self.done = None

        def reset(self, seed=None, options=None):
            super().reset(seed=seed)
            self.board = jnp.zeros((3, 3), dtype=jnp.int32)
            self.current_player = Player.X
            self.done = False
            return self.board, {}

        def step(self, action: jnp.int32):
            """Take the given action in the current state."""
            if self.done:
                raise ValueError(
                    "The game is already over. Call `env.reset()` to reset the environment."
                )

            row, col = divmod(action, 3)
            if self.board[row, col] != Player.EMPTY:
                # Illegal move: penalize the current player and leave the board unchanged.
                return self.board, -10, False, False, {}

            self.board = self.board.at[row, col].set(self.current_player)
            if self.is_winner(self.board, self.current_player):
                # +1 if X (Max) wins, -1 if O (Min) wins.
                reward = +1 if self.current_player == Player.X else -1
                self.done = True
            elif self.is_terminal(self.board):
                reward = 0  # draw: the board is full with no winner
                self.done = True
            else:
                reward = 0
                self.current_player = (
                    Player.O if self.current_player == Player.X else Player.X
                )
            return self.board, reward, self.done, False, {}

        @staticmethod
        def is_terminal(s: Int[Array, "3 3"]):
            """Check if the game is over (someone has won or the board is full)."""
            return (
                TicTacToeEnv.is_winner(s, Player.X)
                or TicTacToeEnv.is_winner(s, Player.O)
                or jnp.all(s != Player.EMPTY)
            )

        @staticmethod
        def is_winner(board: Int[Array, "3 3"], player: Player):
            """Check if the given player has won."""
            return any(
                jnp.all(board[i, :] == player) or
                jnp.all(board[:, i] == player)
                for i in range(3)
            ) or jnp.all(jnp.diag(board) == player) or jnp.all(jnp.diag(jnp.fliplr(board)) == player)

        @staticmethod
        def show(s: Int[Array, "3 3"]):
            """Print the board."""
            for row in range(3):
                print(" | ".join(" XO"[s[row, col]] for col in range(3)))
                if row < 2:
                    print("-" * 5)

8.5 Monte Carlo Tree Search

The task of evaluating actions in a complex environment might seem familiar. We’ve encountered this problem before in both the Chapter 3 setting and the Chapter 1 setting. Now we’ll see how to combine concepts from these to form a more general and efficient tree search heuristic called Monte Carlo Tree Search (MCTS).

When a problem is intractable to solve exactly, we often turn to approximate algorithms that sacrifice some accuracy in exchange for computational efficiency. MCTS also improves on alpha-beta search in this sense. As the name suggests, MCTS uses Monte Carlo simulation, that is, collecting random samples and computing the sample statistics, in order to approximate the value of each action.

As before, we imagine a complete game tree in which each path represents an entire game. The goal of MCTS is to assign values to only the game states that are relevant to the current game; we gradually expand the tree at each move. For comparison, in alpha-beta search, the entire tree only needs to be solved once, and from then on, choosing an action is as simple as taking a maximum over the previously computed values.

The crux of MCTS is approximating the win probability of a state by a sample probability. In practice, MCTS is used for games with binary outcomes where \(r(s) \in \{ +1, -1 \}\), and so this is equivalent to approximating the final game score. To approximate the win probability from state \(s\), MCTS samples random games starting in \(s\) and computes the sample proportion of those that the player wins.
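As a small illustration of this sampling idea, suppose we have some function `random_playout(s)` that plays out a random game from \(s\) and returns the final score \(r \in \{+1, -1\}\); both the function and its name are placeholders here.

Code
def estimate_win_probability(s, random_playout, num_samples: int = 1000) -> float:
    """Fraction of sampled games from state s that end in a win for Max."""
    outcomes = [random_playout(s) for _ in range(num_samples)]
    return sum(1 for r in outcomes if r == +1) / num_samples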

Note that, for a given state \(s\), choosing the best action \(a\) can be framed as a Chapter 3 problem, where each action corresponds to an arm, and the reward distribution of arm \(k\) is the distribution of the game score over random games after choosing that arm. The bandit algorithm most commonly used for MCTS in practice is the upper confidence bound (UCB) algorithm from Section 3.6.

Remark 8.1 (Summary of UCB). Let us quickly review the UCB bandit algorithm. For each arm \(k\), we track the sample mean \[ \hat \mu^k_t = \frac{1}{N_t^k} \sum_{\tau=0}^{t-1} \ind{a_\tau = k} r_\tau \] of all rewards from that arm up to time \(t\). Then we construct a confidence interval \[ C_t^k = [\hat \mu^k_t - B_t^k, \hat \mu^k_t + B_t^k], \] where \(B_t^k = \sqrt{\frac{\ln(2 t / \delta)}{2 N_t^k}}\) is given by Hoeffding’s inequality, so that with probability at least \(1 - \delta\) (where \(\delta\) is some fixed parameter we choose), the true mean \(\mu^k\) lies within \(C_t^k\). Note that \(B_t^k\) scales like \(\sqrt{1/N^k_t}\), i.e. the more we have visited that arm, the more confident we get about it, and the narrower the confidence interval.

To select an arm, we pick the arm with the highest upper confidence bound.
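Here is a small sketch of that selection rule using the quantities from Remark 8.1; the list-based bookkeeping and the default value of \(\delta\) are illustrative choices.

Code
import math


def select_arm_ucb(sample_means: list[float], counts: list[int], t: int, delta: float = 0.05) -> int:
    """Pick the arm with the largest upper confidence bound (Remark 8.1)."""

    def upper_bound(k: int) -> float:
        if counts[k] == 0:
            return float("inf")  # pull every arm at least once
        bonus = math.sqrt(math.log(2 * t / delta) / (2 * counts[k]))  # B_t^k from Hoeffding
        return sample_means[k] + bonus

    return max(range(len(sample_means)), key=upper_bound)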

This means that, for each edge (corresponding to a state-action pair \((s, a)\)) in the game tree, we keep track of the statistics required to compute its UCB:

  • How many times it has been “visited” (\(N_t^{s, a}\))
  • How many of those visits resulted in victory (\(\sum_{\tau=0}^{t-1} \ind{(s_\tau, a_\tau) = (s, a)} r_\tau\)). Let us call this latter value \(W^{s, a}_t\) (for number of “wins”).

What does \(t\) refer to in the above expressions? Recall \(t\) refers to the number of time steps elapsed in the bandit environment. As mentioned above, each state \(s\) corresponds to its own bandit environment, and so \(t\) refers to \(N^s\), that is, how many actions have been taken from state \(s\). This term, \(N^s\), gets incremented as the algorithm runs; for simplicity, we won’t introduce another index to track how it changes.

Definition 8.2 (Monte Carlo tree search algorithm) Inputs:

  • \(T\), the number of iterations per move
  • \(\pi_{\text{rollout}}\), the rollout policy for randomly sampling games
  • \(c\), a positive value that encourages exploration

To choose a single move starting at state \(s_{\text{start}}\), MCTS first tries to estimate the UCB values for each of the possible actions \(\mathcal{A}(s_\text{start})\), and then chooses the best one. To estimate the UCB values, it repeats the following four steps \(T\) times:

  1. Selection: We start at \(s = s_{\text{start}}\). Let \(\tau\) be an empty list that we will use to track states and actions.
    • While every action from \(s\) has already been taken at least once:
      • Choose \(a \gets \arg\max_k \text{UCB}^{s, k}\), where \[ \text{UCB}^{s, a} = \frac{W^{s, a}}{N^{s, a}} + c \sqrt{\frac{\ln N^s}{N^{s, a}}} \tag{8.1}\]
      • Append \((s, a)\) to \(\tau\)
      • Set \(s \gets P(s, a)\)
  2. Expansion: Let \(s_\text{new}\) denote the state reached at the end of the selection step; it has at least one action that hasn’t been taken. Choose one of these unexplored actions and call it \(a_{\text{new}}\). Append \((s_\text{new}, a_\text{new})\) to \(\tau\).
  3. Simulation: Simulate a complete game episode by starting with the action \(a_{\text{new}}\) and then playing according to \(\pi_\text{rollout}\). This results in the outcome \(r \in \{ +1, -1 \}\).
  4. Backup: For each \((s, a) \in \tau\):
    • Set \(N^{s, a} \gets N^{s, a} + 1\)
    • \(W^{s, a} \gets W^{s, a} + r\)
    • Set \(N^s \gets N^s + 1\)

After \(T\) repetitions of the above, we return the action with the highest UCB value (Equation 8.1). Then play continues.

Between turns, we can keep the subtree whose statistics we have already collected; the rest of the tree, corresponding to the actions we did not end up taking, gets discarded.
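The following sketch spells out one selection-expansion-simulation-backup pass of Definition 8.2. It assumes a `game` object exposing `actions(s)`, `step(s, a)`, and `is_terminal(s)`, hashable states, and a `rollout(s)` function that finishes the game with \(\pi_\text{rollout}\) and returns \(r \in \{+1, -1\}\); these names are placeholders, not part of the definition. Following the definition, the statistics are kept from Max’s perspective; accounting for whose turn it is during selection (e.g. negating the score on Min’s turns) is left out of this sketch.

Code
import math
import random
from collections import defaultdict


def mcts_iteration(wins, visits, state_visits, s_start, game, rollout, c: float = 1.0):
    """One selection-expansion-simulation-backup pass of Definition 8.2.

    `wins[(s, a)]` plays the role of W^{s, a}, `visits[(s, a)]` of N^{s, a},
    and `state_visits[s]` of N^s; all three should be defaultdicts (see below).
    """

    def ucb(s, a):
        # Equation 8.1: average outcome plus an exploration bonus.
        return wins[(s, a)] / visits[(s, a)] + c * math.sqrt(
            math.log(state_visits[s]) / visits[(s, a)]
        )

    s, path = s_start, []

    # 1. Selection: descend by UCB while every action from s has been tried.
    while not game.is_terminal(s) and all(visits[(s, a)] > 0 for a in game.actions(s)):
        a = max(game.actions(s), key=lambda a: ucb(s, a))
        path.append((s, a))
        s = game.step(s, a)

    # 2. Expansion: choose an action from s that has never been tried.
    if not game.is_terminal(s):
        a_new = random.choice([a for a in game.actions(s) if visits[(s, a)] == 0])
        path.append((s, a_new))
        s = game.step(s, a_new)

    # 3. Simulation: finish the game with the rollout policy.
    r = rollout(s)  # +1 if Max wins, -1 if Min wins

    # 4. Backup: update the statistics of every edge along the path.
    for s_i, a_i in path:
        wins[(s_i, a_i)] += r
        visits[(s_i, a_i)] += 1
        state_visits[s_i] += 1
    return r


# To choose a move from s_start: run T iterations, then pick the action from
# s_start with the highest UCB value (Equation 8.1), e.g.
#
#   wins, visits, state_visits = defaultdict(float), defaultdict(int), defaultdict(int)
#   for _ in range(T):
#       mcts_iteration(wins, visits, state_visits, s_start, game, rollout)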

The application which brought the MCTS algorithm to fame was DeepMind’s AlphaGo Silver et al. (2016). Since then, it has been used in numerous applications ranging from games to automated theorem proving.

How accurate is this Monte Carlo estimation? It depends heavily on the rollout policy \(\pi_\text{rollout}\). If the distribution that \(\pi_\text{rollout}\) induces over games is very different from the distribution seen during real gameplay, we might end up with a poor value approximation.

8.5.1 Incorporating value functions and policies

To remedy this, we might make use of a value function \(v : \mathcal{S} \to \mathbb{R}\) that more efficiently approximates the value of a state. Then, we can replace the simulation step of Definition 8.2 with evaluating \(r = v(s_\text{next})\), where \(s_\text{next} = P(s_\text{new}, a_\text{new})\).

We might also make use of a “guiding” policy \(\pi_\text{guide} : \mathcal{S} \to \triangle(\mathcal{A})\) that provides “intuition” as to which actions are more valuable in a given state. We can scale the exploration term of Equation 8.1 according to the policy’s outputs.

Putting these together, we can describe an updated version of MCTS that makes use of these value functions and policy:

Definition 8.3 (Monte Carlo tree search with policy and value functions) Inputs:

  • \(T\), the number of iterations per move
  • \(v\), a value function that evaluates how good a state is
  • \(\pi_\text{guide}\), a guiding policy that encourages certain actions
  • \(c\), a positive value that encourages exploration

To select a move in state \(s_\text{start}\), we repeat the following four steps \(T\) times:

  1. Selection: We start at \(s = s_{\text{start}}\). Let \(\tau\) be an empty list that we will use to track states and actions.
    • While every action from \(s\) has already been taken at least once:
      • Choose \(a \gets \arg\max_k \text{UCB}^{s, k}\), where \[ \text{UCB}^{s, a} = \frac{W^{s, a}}{N^{s, a}} + c \cdot \pi_\text{guide}(a \mid s) \sqrt{\frac{\ln N^s}{N^{s, a}}} \tag{8.2}\]
      • Append \((s, a)\) to \(\tau\)
      • Set \(s \gets P(s, a)\)
  2. Expansion: Let \(s_\text{new}\) denote the state reached at the end of the selection step; it has at least one action that hasn’t been taken. Choose one of these unexplored actions and call it \(a_{\text{new}}\). Append \((s_\text{new}, a_\text{new})\) to \(\tau\).
  3. Simulation: Let \(s_\text{next} = P(s_\text{new}, a_\text{new})\). Evaluate \(r = v(s_\text{next})\). This approximates the value of the game after taking the action \(a_\text{new}\).
  4. Backup: For each \((s, a) \in \tau\):
    • \(N^{s, a} \gets N^{s, a} + 1\)
    • \(W^{s, a} \gets W^{s, a} + r\)
    • \(N^s \gets N^s + 1\)

We finally return the action with the highest UCB value (Equation 8.2). Then play continues. As before, we can reuse the tree across timesteps.

Code
from collections import defaultdict


class EdgeStatistics(NamedTuple):
    wins: int = 0
    visits: int = 0


class MCTSTree:
    """A representation of the search tree.

    Maps each state-action pair to its number of wins and the number of visits.
    """

    edges: dict[tuple["State", "Action"], EdgeStatistics]

    def __init__(self):
        self.edges = defaultdict(EdgeStatistics)


def mcts_iter(tree: MCTSTree, s_init):
    s = s_init
    # Selection would repeatedly descend the tree while every action from s
    # already has an edge, e.g.:
    # while all((s, a) in tree.edges for a in env.action_state(s)):
    ...
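To make the difference from Equation 8.1 concrete, here is a sketch of the selection score in Equation 8.2, written against the `MCTSTree` above. The `n_s` argument stands for \(N^s\) and `guide(s, a)` stands for \(\pi_\text{guide}(a \mid s)\); both names are hypothetical.

Code
import math


def ucb_with_guide(tree: MCTSTree, n_s: int, s, a, guide, c: float = 1.0) -> float:
    """The selection score of Equation 8.2 for the edge (s, a).

    `n_s` is the state visit count N^s and `guide(s, a)` returns the guiding
    policy's probability pi_guide(a | s); both are supplied by the caller.
    """
    edge = tree.edges[(s, a)]
    if edge.visits == 0:
        return float("inf")  # untried actions get expanded first
    exploit = edge.wins / edge.visits
    explore = c * guide(s, a) * math.sqrt(math.log(max(n_s, 1)) / edge.visits)
    return exploit + explore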

How do we actually compute a useful \(\pi_\text{guide}\) and \(v\)? If we have some existing dataset of trajectories, we could use Chapter 7 (that is, imitation learning) to generate a policy \(\pi_\text{guide}\) via behavioral cloning and learn \(v\) by regressing the game outcomes onto states. Then, plugging these into Definition 8.3 results in a stronger policy by using tree search to “think ahead”.

But we don’t have to stop at just one improvement step; we could iterate this process via self-play.

8.5.2 Self-play

Recall the policy iteration algorithm of Section 1.3.7.2 from Chapter 1. Policy iteration alternates between policy evaluation (taking \(\pi\) and computing \(V^\pi\)) and policy improvement (setting \(\pi\) to be greedy with respect to \(V^\pi\)). Above, we saw how MCTS can be thought of as a “policy improvement” operation: for a given policy \(\pi^0\), we can use it to guide MCTS, resulting in an algorithm that is itself a policy \(\pi^0_\text{MCTS}\) that maps from states to actions. Now, we can use the imitation learning techniques of Chapter 7 to obtain a new policy \(\pi^1\) that imitates \(\pi^0_\text{MCTS}\). We can then use \(\pi^1\) to guide MCTS, and repeat.

Definition 8.4 (MCTS with self-play) Input:

  • A parameterized policy class \(\pi_\theta : \mathcal{S} \to \triangle(\mathcal{A})\)
  • A parameterized value function class \(v_\lambda : \mathcal{S} \to \mathbb{R}\)
  • A number of trajectories \(M\) to generate
  • The initial parameters \(\theta^0, \lambda^0\)

For \(t = 0, \dots, T-1\):

  • Policy improvement: Let \(\pi^t_\text{MCTS}\) denote the policy obtained by Definition 8.3 with \(\pi_{\theta^t}\) and \(v_{\lambda^t}\). We use \(\pi^t_\text{MCTS}\) to play against itself \(M\) times. This generates \(M\) trajectories \(\tau_0, \dots, \tau_{M-1}\).
  • Policy evaluation: Use behavioral cloning to find a set of policy parameters \(\theta^{t+1}\) that mimic the behavior of \(\pi^t_\text{MCTS}\) and a set of value function parameters \(\lambda^{t+1}\) that approximate its value function. That is, \[ \begin{aligned} \theta^{t+1} &\gets \arg\min_\theta \sum_{m=0}^{M-1} \sum_{\hi=0}^{H-1} - \log \pi_\theta(a^m_\hi \mid s^m_\hi) \\ \lambda^{t+1} &\gets \arg\min_\lambda \sum_{m=0}^{M-1} \sum_{\hi=0}^{H-1} (v_\lambda(s^m_\hi) - R(\tau_m))^2 \end{aligned} \]

Note that in implementation, the policy and value are typically both returned by a single deep neural network, that is, with a single set of parameters, and the two loss functions are added together.
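To illustrate this remark, here is a sketch of the combined objective from Definition 8.4 under the assumption of a single network with a policy head and a value head. The function names (`policy_logits`, `value`) and the batched trajectory format are illustrative, not taken from the text.

Code
import jax
import jax.numpy as jnp


def self_play_loss(params, states, actions, returns, policy_logits, value):
    """Behavioral cloning loss plus value regression loss (Definition 8.4).

    `states` is a batch of states visited by pi_MCTS, `actions` the actions it
    took, and `returns` the final outcome R(tau) of each trajectory.
    """
    # Policy head: negative log-likelihood of the MCTS policy's actions.
    log_probs = jax.nn.log_softmax(policy_logits(params, states))  # (N, |A|)
    bc_loss = -jnp.take_along_axis(log_probs, actions[:, None], axis=1).mean()

    # Value head: regress the predicted value onto the observed game outcome.
    value_loss = jnp.mean((value(params, states) - returns) ** 2)

    # In practice the two heads share parameters, so the losses are simply added.
    return bc_loss + value_loss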

This algorithm was brought to fame by AlphaGo Zero Silver et al. (2017).

8.6 Summary

In this chapter, we explored tree search-based algorithms for deterministic, zero sum, fully observable two-player games. We began with Section 8.3, an algorithm for exactly solving the game value of every possible state. However, this is impossible to execute in practice, and so we must resort to various ways to reduce the number of states and actions that we must explore. Section 8.4 does this by pruning away states that we already know to be suboptimal, and Section 8.5 approximates the value of states instead of evaluating them exactly.

8.7 References

Chapter 5 of Russell and Norvig (2021) provides an excellent overview of search methods in games. The original AlphaGo paper Silver et al. (2016) was a groundbreaking application of these technologies. Silver et al. (2017) removed the imitation learning phase, learning from scratch. AlphaZero Silver et al. (2018) then extended to other games beyond Go, namely shogi and chess, also learning from scratch. In MuZero Schrittwieser et al. (2020), this was further extended by learning a model of the game dynamics.