How to teach your computer to play video games
Teaching your computer to play video games has all the components of a sublime storyline:
- it is ingenious
- beautiful in its simplicity
- and utterly surprising
I will explain the main ideas of deep Q learning, using as few big and scary nouns as possible, as we teach our computer to play the classic Atari game Breakout.
The code I wrote for this blog post is available here. It's a simplified version of the already concise implementation from cleanrl.
The first act – meeting the protagonist
Our story has a single humble hero: the QNetwork, a small convolutional neural network (CNN) that does a lot of heavy lifting.
A CNN is a type of neural network designed to operate on adjacent numbers that are likely related, such as the RGB values that make up a digital image.
The QNetwork is tiny by modern standards, with only 1.69 million trainable parameters. For comparison, state-of-the-art language models like GPT-4 are rumored to have nearly 2 trillion parameters!
Despite its small size, the QNetwork can be trained to do something extremely useful.
It takes in a game image and outputs 4 numbers, each corresponding to a possible action:
- do nothing
- move left
- move right
- or fire.
These output numbers, called q-values, represent the estimated score the game-playing agent will achieve by the end of the game if it takes a specific action given the current game state.
By always selecting the action with the highest predicted q-value, the agent can aim to maximize its final score.
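To make this concrete, here is a minimal PyTorch sketch of what such a network can look like. The layer sizes follow the classic DQN architecture (also what cleanrl uses), which is roughly where the 1.69 million parameters mentioned above come from; treat the exact numbers as an illustration, not a requirement:

import torch.nn as nn

class QNetwork(nn.Module):
    """Maps a stack of game frames to one q-value per action."""

    def __init__(self, n_actions=4):
        super().__init__()
        self.network = nn.Sequential(
            # input: 4 stacked 84x84 grayscale frames (more on the stacking later)
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),  # 4 outputs: do nothing, move left, move right, fire
        )

    def forward(self, x):
        return self.network(x / 255.0)  # scale pixel values to [0, 1]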
But how do we train the QNetwork to output accurate q-values?
The answer lies in the Bellman equation:
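In symbols (a rough rendering, with s the current game state, a the action we take, r the reward we receive for it, and s' the state we land in):

Q(s, a) = r + gamma * max_a' Q(s', a')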
In plain English, it states the following:
To estimate the total reward our computer program will achieve by the end of the game:
- find out the reward it will get for taking a specific action at the current decision point (this action transitions us to some next state in the game)
- find out how many points it is likely to additionally collect from the new state it has just reached till the end of the game (and discount that estimate by a coefficient gamma to account for some of the inherent uncertainty about the future)
This looks very much like a recursion!
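To make the recursion concrete, here is a toy calculation with made-up numbers:

reward_now = 1.0            # points earned for the action we take at this decision point
best_future_score = 10.0    # best score the network thinks it can still collect from the next state
gamma = 0.99                # discount coefficient for uncertainty about the future
estimated_total = reward_now + gamma * best_future_score   # 1.0 + 0.99 * 10.0 = 10.9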
Additionally, the Bellman equation has an interesting property:
- if you can do the two steps above, and
- if you play the game and always pick the action with the highest q-value at every decision point you encounter,
then you will be playing the game optimally!
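In code, "always pick the action with the highest q-value" is just a loop. Here is a sketch, assuming a simplified environment interface where step returns the next screen, the reward, and a done flag:

state = environment.reset()
done = False
while not done:
    q_values = QNetwork(state)      # 4 estimated final scores, one per action
    action = argmax(q_values)       # always take the most promising one
    state, reward, done = environment.step(action)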
Okay, so we know what our QNetwork needs to do, but how do we train it?
We don't have the ground truth (that is, we don't have the reward our program will receive by taking a given action at timestep t and continuing to play the game till the finish).
But we do have enough signal to go by!
This is what we end up doing:
At any decision point, we show the game screen to our QNetwork, perform the action it recommends, and get a reward.
So if we define the QNetwork as predicting the sum of the rewards (scores) it is likely to earn till the end of the episode when presented with some game state t, we get the following:
# current_state == state at game step t
q_values = QNetwork(current_state)                 # one estimated final score per action
action = argmax(q_values)                          # greedily pick the most promising action
next_state, reward_at_t = environment.step(action)
predicted_score = q_values[action]                 # the score we predicted for the action we took
# Bellman target: immediate reward plus the discounted best score achievable from the next state
ground_truth_score = reward_at_t + gamma * max(QNetwork(next_state))
And if we have the ground_truth_score (even if it is only estimated), we can still calculate the difference between the predicted_score and the ground_truth_score (our loss) and tell the network to do better next time by updating its parameters via backpropagation.
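In PyTorch-flavored pseudocode, that update might look roughly like this (a squared-error loss is one common choice, and the optimizer, e.g. torch.optim.Adam over the QNetwork's parameters, is assumed to already exist):

import torch.nn.functional as F

loss = F.mse_loss(predicted_score, ground_truth_score.detach())  # don't backpropagate through the target
optimizer.zero_grad()   # clear gradients left over from the previous update
loss.backward()         # backpropagation: how should each parameter change to shrink the loss?
optimizer.step()        # nudge the parameters in that direction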
And, surprisingly, this works! 🙂
A word of caution
This blog post provides a broad overview of the topic.
You can understand the main concept without delving into all the fine details.
However, in actual code these details are crucial for the formulation to work, and it can be challenging to spot the core idea among the various hacks the technique requires.
Everything you read here is conceptually correct, although some details have been omitted.
For instance, the neural network doesn't see a single frame from the game; instead, we show it 4 frames at a time, as discerning the direction of movement would be impossible with just a single frame.
Additionally, on Atari, some frames may be partially rendered due to hardware limitations. Having more than one frame helps here, as you can take the pixel-wise maximum over consecutive frames to recover a complete rendition.
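For the curious, both tricks amount to only a few lines. Here is a hand-rolled sketch (in practice, libraries such as gymnasium ship ready-made wrappers for this):

import numpy as np
from collections import deque

recent_frames = deque(maxlen=4)        # always keeps the 4 most recent frames

def add_frame(raw_frame, previous_raw_frame):
    # pixel-wise max over two consecutive raw frames hides Atari's partial rendering
    recent_frames.append(np.maximum(raw_frame, previous_raw_frame))

def observation():
    # stack the last 4 frames so the network can see motion, not just position
    return np.stack(recent_frames, axis=0)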
However, I didn't want to focus on such details in this post, as they obscure the most significant aspect: the core concept behind how deep Q networks work in all their astonishing prowess and simplicity!
If you would like to further explore this method, I highly recommend examining what comprises the training set, as this will reveal another layer of what makes this approach train so well.
But for now, let's take a moment to marvel at the simplicity behind the main idea and the astounding results that a small neural network can deliver: