How to teach your computer to play video games

Teaching your computer to play video games has all the makings of a great story. It is:

  • ingenious
  • beautiful in its simplicity
  • utterly surprising

I'll explain the main ideas behind deep Q-learning, using as few big and scary nouns as possible, as we teach our computer to play the classic Atari game Breakout.

The paddle at the bottom is controlled by a model I trained for this blog post.

The code I wrote is available here. It's a simplified version of the already concise implementation from cleanrl.

Meeting the protagonist

Our story has a single humble hero: the QNetwork, a small convolutional neural network (CNN) that does a lot of heavy lifting.

A CNN is a type of neural network designed to operate on adjacent numbers that are likely related — such as the RGB values that make up a digital image.

A picture consists of 3 arrays of numbers (RGB values).
Image credit: Diane Rohrer

The QNetwork is tiny by modern standards, with only 1.69 million trainable parameters. For comparison, state-of-the-art language models like GPT-4 are rumored to have nearly 2 trillion!
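To get a feel for where a number like 1.69 million comes from, here is a rough back-of-the-envelope count. The layer sizes below are an assumption: they follow the classic DQN Atari layout (three convolutions and two fully connected layers, applied to 4 stacked 84x84 frames), which may differ in detail from the code linked above.

```python
def conv_params(in_channels, out_channels, kernel):
    """Weights plus one bias per output channel, for a square kernel."""
    return in_channels * out_channels * kernel * kernel + out_channels


def linear_params(in_features, out_features):
    """Weights plus one bias per output feature."""
    return in_features * out_features + out_features


def conv_out(size, kernel, stride):
    """Side length of the output of an unpadded convolution."""
    return (size - kernel) // stride + 1


# Assumed architecture: 4 input frames of 84x84 pixels, 4 output actions.
side = 84
side = conv_out(side, kernel=8, stride=4)
side = conv_out(side, kernel=4, stride=2)
side = conv_out(side, kernel=3, stride=1)
flat = 64 * side * side  # features fed into the first fully connected layer

total_params = (
    conv_params(4, 32, 8)
    + conv_params(32, 64, 4)
    + conv_params(64, 64, 3)
    + linear_params(flat, 512)
    + linear_params(512, 4)
)
print(total_params)  # 1686180
```

Under these assumptions, almost all of the parameters sit in the first fully connected layer, not in the convolutions.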

And yet this small network can be trained to do something remarkably useful. It takes in a game image and outputs 4 numbers, one for each possible action:

  • do nothing
  • move left
  • move right
  • fire

These numbers are called q-values. Each one estimates the total reward the agent can expect to collect from the current state until the end of the game if it takes that action now and keeps playing well afterwards. So by always picking the action with the highest q-value, the agent aims to maximize its score.
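Picking the action with the highest q-value can be sketched in a few lines. The q-values below are made up for illustration, and the action order simply follows the list above; the index order used by a real Atari environment may differ.

```python
ACTIONS = ["do nothing", "move left", "move right", "fire"]


def greedy_action(q_values):
    """Return the index of the largest q-value (ties go to the first one)."""
    return max(range(len(q_values)), key=lambda i: q_values[i])


# Hypothetical network output for one game screen.
q_values = [0.12, 1.30, 0.75, -0.20]
best = greedy_action(q_values)
print(ACTIONS[best])  # move left
```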

But how do we train the QNetwork to output accurate q-values?

The answer lies in the Bellman equation:

Bellman equation
image from the fabulous RL course by HuggingFace!

To estimate the total reward our program will collect by the end of the game, the equation says to add up two things:

  • the reward it gets for taking a specific action at the current decision point (which moves us to some next state in the game)
  • the additional points it's likely to collect from that next state until the end of the game — discounted by a coefficient gamma, to account for the inherent uncertainty about the future
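Written out for q-values, the two parts above correspond to one common form of the equation. The symbols are assumed notation and may differ from the image: s_t is the current state, a_t the action taken, r_t the reward received, and gamma the discount coefficient. For simplicity this is the deterministic case, without an expectation.

```latex
Q(s_t, a_t) = r_t + \gamma \max_{a'} Q(s_{t+1}, a')
```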

This looks very much like a recursion!
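To make the recursion concrete, here is a small sketch that adds up a list of rewards the same way: the first reward, plus gamma times the total of everything after it. The rewards and the value of gamma are made up for illustration.

```python
def discounted_return(rewards, gamma):
    """Reward now, plus the discounted total of everything after it."""
    if not rewards:
        return 0.0
    return rewards[0] + gamma * discounted_return(rewards[1:], gamma)


total = discounted_return([1.0, 0.0, 2.0], gamma=0.5)
print(total)  # 1.5
```

Here the reward of 2 arrives two steps in the future, so it is discounted twice: 1 + 0.5 * 0 + 0.25 * 2 = 1.5.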

The Bellman equation also has a remarkable property. If your q-values satisfy it exactly, and you play the game always picking the action with the highest q-value, then you will be playing the game optimally.

Okay — so we know what our QNetwork needs to do. But how do we actually train it?

The catch is that we don't have the ground truth. We don't know in advance the reward our program will receive by taking a given action at timestep t and playing on to the finish. But we do have enough signal to go by.

Here's the trick. We define the QNetwork as predicting the sum of the rewards (scores) it's likely to earn from a given state until the end of the episode. At any decision point, we show the game screen to the QNetwork, perform the action it recommends, and collect a reward:

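# current_state == state at game step t
q_values = QNetwork(current_state)   # 4 numbers, one per action
action = argmax(q_values)            # the action the network recommends
predicted_score = q_values[action]

next_state, reward_at_t = environment.step(action)

# bootstrap from the best q-value in the next state, discounted by gamma
ground_truth_score = reward_at_t + gamma * max(QNetwork(next_state))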

Now we have a ground_truth_score — and even though it's only an estimate, we can still measure the difference between it and the predicted_score (our loss). We then tell the network to do better next time by updating its parameters via backpropagation.
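The update itself can be sketched without a neural network. In the toy version below, a plain table of q-values stands in for the QNetwork, and a single gradient step on the squared error replaces backpropagation. The table, the state names, GAMMA and LEARNING_RATE are all hypothetical choices for illustration.

```python
GAMMA = 0.99          # discount coefficient (assumed value)
LEARNING_RATE = 0.1   # step size (assumed value)


def td_update(q_table, state, action, reward, next_state, done):
    """Nudge one q-value toward its target; return the squared error before the nudge."""
    predicted = q_table[state][action]
    # No future reward once the episode has ended.
    future = 0.0 if done else max(q_table[next_state])
    target = reward + GAMMA * future
    error = target - predicted
    # Gradient descent on 0.5 * error**2 with respect to the predicted value.
    q_table[state][action] = predicted + LEARNING_RATE * error
    return error ** 2


# Two made-up states with 4 q-values each.
q_table = {"s0": [0.0, 0.0, 0.0, 0.0], "s1": [0.0, 2.0, 0.0, 0.0]}
loss_before = td_update(q_table, "s0", 2, 1.0, "s1", done=False)
loss_after = td_update(q_table, "s0", 2, 1.0, "s1", done=False)
print(loss_before > loss_after)  # True
```

Repeating the same update shrinks the loss, because each step moves the prediction a little closer to the target.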

And, surprisingly, this works! 🙂

A word of caution

This blog post is a broad overview. You can grasp the main idea without wading into every detail — but be warned that in actual code those details matter a great deal, and the technique relies on a number of hacks to make the formulation work.

Everything you've read here is conceptually correct; I've simply left some details out. For instance, the network doesn't see a single frame of the game. We show it 4 frames at a time, because you can't tell which direction anything is moving from a single frame. On top of that, some Atari frames are only partially rendered due to hardware limitations, so taking the pixel-wise maximum over consecutive frames recovers a complete picture.
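Both preprocessing tricks can be sketched in plain Python. The helper names below are hypothetical, frames are shown as tiny nested lists instead of real screens, and filling the stack with copies of the first frame is an assumed convention.

```python
from collections import deque


def max_over_frames(frame_a, frame_b):
    """Pixel-wise maximum of two equally sized frames."""
    return [
        [max(a, b) for a, b in zip(row_a, row_b)]
        for row_a, row_b in zip(frame_a, frame_b)
    ]


class FrameStack:
    """Keeps the most recent `size` frames, oldest first."""

    def __init__(self, size=4):
        self.frames = deque(maxlen=size)

    def push(self, frame):
        if not self.frames:
            # At the start of an episode, fill the stack with the first frame.
            for _ in range(self.frames.maxlen):
                self.frames.append(frame)
        else:
            self.frames.append(frame)  # the oldest frame drops out automatically
        return list(self.frames)


# Two partially rendered 2x2 frames combine into one complete frame.
merged = max_over_frames([[0, 5], [3, 0]], [[2, 0], [0, 4]])
stack = FrameStack(size=4)
first = stack.push(merged)
second = stack.push([[9, 9], [9, 9]])
print(merged)       # [[2, 5], [3, 4]]
print(len(second))  # 4
```

The list returned by push is what the network would see as its input: 4 frames, with the newest one last.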

I left these details out on purpose, because dwelling on them obscures the most important thing — the core idea behind how deep Q networks work, in all their astonishing power and simplicity!

If you'd like to explore the method further, I highly recommend looking closely at what makes up the training set. That reveals another layer of why this approach trains so well.

But for now, let's just take a moment to marvel at the simplicity of the main idea, and the astounding results a small neural network can deliver: