Understanding Policy Gradient - a fundamental idea in RL
How do you begin to learn Reinforcement Learning?
My preferred approach is to study code.
Reading and analyzing code can help disambiguate many ideas and concepts in papers or blog posts that can be hard to understand otherwise.
Crafting an optimal policy by learning value functions is a very straightforward idea. A good example is Deep Q-Learning, which I covered in an earlier blog post.
But policy gradient methods are a completely different beast!
They are unlike anything someone coming from the supervised or semi-supervised deep learning world is likely to have experienced before.
If you have a background in probability and can follow math notation, I highly recommend the Intro to Policy Optimization from Spinning Up by OpenAI.
On top of a background in probability, the other prerequisites include:
- understanding how log odds (logits) passed through a Sigmoid or a Softmax become probabilities
- understanding the negative log-likelihood loss
- having a working understanding of the chain rule
In this blog post, I plan to explain the main ideas behind policy gradient without relying on any mathematical heavy lifting.
This should allow anyone to jump into policy gradient methods from the get-go and make it easier to patch up the missing theoretical pieces later, if you want to.
On top of that, it is valuable to engage with the ideas you are learning at different levels of complexity and rigor. Even if you can follow the math with ease, there is still value in understanding the why behind some of the operations and in seeing the bigger picture of the method.
So let's get started.
One equation to rule them all
The equation below is at the heart of policy gradient methods:
\[ J(\pi_\theta) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\left[ R(\tau) \right] \]
\(J\) is the expected return of using policy \(\pi\) parametrized by weights \(\theta\).
To make this more concrete, \(\pi\) can be a neural network that takes in observations (the state of the environment) and outputs a distribution over possible actions.
We sample from that distribution to decide which action to take.
Now, what we want to do down the road is maximize the expected returns. We do so by modifying our policy's parameters \(\theta\).
So far, so good.
The right-hand side of the equation gives us the recipe to calculate \(J\) (the expected return).
It is an expectation over trajectories \(\tau\) sampled from policy \(\pi_\theta\), with a return assigned to each trajectory by the reward function \(R\).
What a mouthful!
But the idea is simple:
- collect a number of trajectories by following policy \(\pi\)
- for each trajectory, compute its return using function \(R\)
Once you have enough trajectories, take the mean of their returns, and ... you will get an estimate of the expected return!
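Here is a minimal sketch of that recipe, assuming a toy, stateless "environment" where every step pays a fixed reward depending only on the chosen action. The function names, the policy probabilities and the rewards are all hypothetical, chosen so that the exact value of \(J\) is easy to check by hand:

```python
import numpy as np

def sample_trajectory(rng, probs, rewards_per_action, horizon):
    """Hypothetical toy rollout: at each of `horizon` steps, sample an action
    from `probs` and collect that action's fixed reward. Returns R(tau)."""
    actions = rng.choice(len(probs), size=horizon, p=probs)
    return float(np.sum(rewards_per_action[actions]))

def estimate_J(rng, probs, rewards_per_action, horizon, n_trajectories):
    # Monte Carlo estimate of J: the mean of R(tau) over sampled trajectories.
    returns = [sample_trajectory(rng, probs, rewards_per_action, horizon)
               for _ in range(n_trajectories)]
    return float(np.mean(returns))

rng = np.random.default_rng(0)
probs = np.array([0.2, 0.8])               # hypothetical policy pi(a)
rewards_per_action = np.array([0.0, 1.0])  # hypothetical reward for each action
horizon = 10

estimate = estimate_J(rng, probs, rewards_per_action, horizon, n_trajectories=5000)
exact = horizon * float(probs @ rewards_per_action)  # 10 * 0.8 in this toy setup
print(f"estimate={estimate:.3f}, exact={exact:.3f}")
```

In a real environment we cannot compute `exact`. The sample mean is all we get, which is why the number of trajectories matters.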
There is one crucial aspect to keep in mind here.
When the variance of the collected data is higher, a larger sample size is needed to estimate the mean of a population accurately.
So the higher the variance of what we put inside the square brackets on the right (\(R(\tau)\)), the more of it – in this case, the trajectories and their associated rewards – we will need to collect to obtain a good estimation.
Otherwise, our estimate will be noisy, and our policy may not train well.
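To see the effect of variance on such an estimate, the sketch below is a simulation under an assumption: returns are drawn from a normal distribution (a stand-in for real trajectories) with the same mean and two different spreads. For each spread, it estimates the mean many times and measures how much those estimates scatter:

```python
import numpy as np

rng = np.random.default_rng(1)
true_mean = 5.0
n_trajectories = 50   # returns averaged into each estimate of J
n_repeats = 2000      # independent estimates we compare

def spread_of_estimates(return_std):
    # Simulated returns R(tau): same mean, different standard deviation.
    returns = rng.normal(true_mean, return_std, size=(n_repeats, n_trajectories))
    estimates = returns.mean(axis=1)   # one estimate of J per row
    return float(estimates.std())      # how much those estimates scatter

low = spread_of_estimates(return_std=1.0)
high = spread_of_estimates(return_std=10.0)
print(f"low-variance returns:  estimates scatter by ~{low:.3f}")
print(f"high-variance returns: estimates scatter by ~{high:.3f}")
```

The scatter of a sample mean shrinks like \(\sigma / \sqrt{N}\). Returns with 10 times the standard deviation therefore need roughly 100 times as many trajectories to reach the same precision.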
Ok, great. We now have a way of estimating expected returns.
To maximize them, we do the most straightforward thing in the universe: we calculate the gradient of \(J\) with respect to the policy parameters \(\theta\) and take a step in the positive direction!
We perform gradient ascent.
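Written out, a single step of gradient ascent moves the parameters a little in the direction of the gradient. Here \(\alpha\) is a step size (learning rate) that we choose, and \(k\) counts update steps:

\[ \theta_{k+1} = \theta_k + \alpha \, \nabla_\theta J(\pi_\theta)\big|_{\theta_k} \]

The only difference from ordinary gradient descent is the plus sign in front of the gradient term.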
But how do you calculate the gradient of a policy?
Derivation of simple policy gradient
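This is where the scary-looking math usually shows up.
But again, the idea is simple.
We can bypass the derivation and skip directly to the bottom line:
\[ \nabla_\theta J(\pi_\theta) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t)\, R(\tau) \right] \]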
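For readers who would rather not bypass it entirely, here is a compact sketch of the standard derivation, following the steps in Spinning Up's Intro to Policy Optimization. It relies on two facts. The first is the log-derivative trick, \(\nabla_\theta P = P \, \nabla_\theta \log P\). The second is that the probability of a trajectory, \(P(\tau|\theta)\), is a product of the start-state distribution, the environment's transition probabilities and the policy's action probabilities. Only the last factor depends on \(\theta\), so the gradient of \(\log P(\tau|\theta)\) reduces to a sum of \(\nabla_\theta \log \pi_\theta(a_t|s_t)\) terms:

\[
\begin{aligned}
\nabla_\theta J(\pi_\theta) &= \nabla_\theta \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\left[ R(\tau) \right] \\
&= \nabla_\theta \int_\tau P(\tau|\theta)\, R(\tau) \\
&= \int_\tau \nabla_\theta P(\tau|\theta)\, R(\tau) \\
&= \int_\tau P(\tau|\theta)\, \nabla_\theta \log P(\tau|\theta)\, R(\tau) \\
&= \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\left[ \nabla_\theta \log P(\tau|\theta)\, R(\tau) \right] \\
&= \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t)\, R(\tau) \right]
\end{aligned}
\]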
The \(\mathop{\mathbb{E}}\) again tells us that we will be dealing with a mean of values collected over a sufficient number of trajectories.
Inside the square brackets, we take the logarithm of the probabilities our model assigns to the actions it took and weight them by rewards.
The plan is to maximize the log probs \(\log \pi_\theta(a_t|s_t)\) across the board, but not equally!
The higher the reward associated with the log probs, the more we want them to increase!
As a result, over a batch of trajectories, the gradients will tell us how to modify parameters \(\theta\) so that the log probs of actions associated with higher rewards increase!
It doesn't matter that log probs are never positive. The closer they are to 0, the closer the probability is to 1.
And we want to bring the probability of actions leading to higher rewards as close to 1 as possible!
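As a minimal sketch of this idea, assume a hypothetical one-step problem with three actions and fixed rewards, and a softmax policy whose parameters \(\theta\) are simply the three logits. For a softmax, the gradient of \(\log \pi_\theta(a)\) with respect to the logits is `one_hot(a) - pi`. Weighting it by the reward and averaging over a batch gives a sample estimate of the policy gradient:

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max())   # subtract max for numerical stability
    return e / e.sum()

def grad_log_prob(logits, action):
    # d log pi(action) / d logits = one_hot(action) - pi  (softmax policy)
    g = -softmax(logits)
    g[action] += 1.0
    return g

rewards = np.array([1.0, 2.0, 5.0])   # hypothetical reward of each action
logits = np.zeros(3)                  # theta: start from a uniform policy
learning_rate = 0.05
rng = np.random.default_rng(0)

for step in range(200):
    # A batch of "trajectories", each a single action sampled from the policy.
    actions = rng.choice(3, size=32, p=softmax(logits))
    # Mean of grad log pi(a) * R: every sampled log prob gets pushed,
    # with strength proportional to the reward it earned.
    grad = np.mean([grad_log_prob(logits, a) * rewards[a] for a in actions], axis=0)
    logits += learning_rate * grad    # gradient ascent

print(np.round(softmax(logits), 3))  # most probability mass ends up on action 2
```

All three rewards are positive, yet the probability of the best action grows at the expense of the others. Its log prob receives the strongest push, and the softmax keeps the probabilities summing to 1.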
We need to account for only one minor detail in the code.
We use a PyTorch optimizer, which minimizes the quantity whose gradients we calculated.
That is because optimizers are designed for cost functions, which usually output higher values the greater the discrepancy between predictions and labels. We want to make that discrepancy as small as possible.
Here, this concept doesn't apply at all.
There are no labels, and we only manipulate the parameters of our policy to increase the score.
Still, as we use an off-the-shelf optimizer, we put a minus sign in front of the quantity whose gradient estimates \(\nabla_\theta J\) (the mean of the reward-weighted log probs).
By minimizing its negative, we perform gradient ascent on \(J\)!
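In PyTorch this sign flip typically looks like `loss = -(logp * weights).mean()`, followed by `loss.backward()` and `optimizer.step()`. The same pattern appears in Spinning Up's simple policy gradient example. To avoid depending on any autodiff library, the sketch below checks the idea with NumPy and finite differences. It assumes a softmax policy over three actions, and the sampled actions and their weights are made up for illustration. One descent step on the negated quantity lands exactly where one ascent step on the original quantity would:

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max()
    return z - np.log(np.exp(z).sum())

def objective(logits, actions, weights):
    # Surrogate whose gradient estimates grad J: mean of log pi(a) * weight.
    return float(np.mean(log_softmax(logits)[actions] * weights))

def loss(logits, actions, weights):
    # What we hand to a minimizing optimizer: the same quantity, negated.
    return -objective(logits, actions, weights)

def numerical_grad(f, x, eps=1e-6):
    # Central finite differences in place of autodiff.
    g = np.zeros_like(x)
    for i in range(len(x)):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return g

logits = np.array([0.5, -0.2, 0.1])
actions = np.array([0, 2, 2, 1])           # hypothetical sampled actions
weights = np.array([1.0, 3.0, 3.0, -1.0])  # hypothetical returns R(tau)
lr = 0.1

grad_loss = numerical_grad(lambda x: loss(x, actions, weights), logits)
grad_obj = numerical_grad(lambda x: objective(x, actions, weights), logits)

descent_on_loss = logits - lr * grad_loss      # what the optimizer does
ascent_on_objective = logits + lr * grad_obj   # what we actually want
print(np.allclose(descent_on_loss, ascent_on_objective))
```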
Improving our algorithm
The more we can reduce the variance of our policy gradient estimate, the better and quicker Vanilla Policy Gradient will train.
One useful idea is changing how we assign rewards to each observation-action pair.
In the original formulation, we took the return of the whole trajectory, \(R(\tau)\), and assigned it equally to every action.
But if you think about it, why should an action be rewarded or penalized for rewards that were collected before it was even taken?
After all, the only thing our action at a given timestep can influence is the future!
This is where the concept of reward-to-go comes into play.
Instead of weighting every action by the sum of all the rewards in the trajectory, we use the sum of rewards from the given timestep forward. For each action, we add up the rewards collected from that step until the end of the episode and use this sum, the reward-to-go, as the weight for that specific action.
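Computing the reward-to-go for a single episode takes one backward pass over its rewards. Here is a minimal sketch; the function name and the per-step rewards are hypothetical:

```python
import numpy as np

def reward_to_go(rewards):
    """For each timestep t, sum the rewards from t until the end of the episode."""
    rtg = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running += rewards[t]
        rtg[t] = running
    return rtg

episode_rewards = [1.0, 0.0, 2.0, 1.0]   # hypothetical per-step rewards
print(reward_to_go(episode_rewards))      # [4. 3. 3. 1.]
print(sum(episode_rewards))               # 4.0, the R(tau) every action got before
```

Only the first action still receives the full return. Later actions are no longer credited with rewards that arrived before they were taken.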
The expected grad-log-prob lemma is another useful result that we can leverage here.
The consequence of the lemma is that for any function \(b(s_t)\) that depends only on the state, the expected value of \(\nabla_\theta \log \pi_\theta(a_t|s_t)\, b(s_t)\) over actions sampled from the policy is 0.
This means we can add or subtract as many such terms as we like in our calculation of the Vanilla Policy Gradient, and the expected gradient stays the same.
This is a beneficial result because we can use such functions (called baselines) to reduce the variance of our gradient estimate!
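For a softmax policy, the lemma can be checked exactly. Averaging `one_hot(a) - pi` weighted by \(\pi(a)\) gives `pi - pi = 0`, so any constant factor \(b(s)\) multiplies zero. The sketch below computes that expectation numerically. The logits and the baseline value are arbitrary, hypothetical numbers standing in for one particular state \(s\):

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

def grad_log_prob(logits, action):
    # Gradient of log pi(action) with respect to the logits of a softmax policy.
    g = -softmax(logits)
    g[action] += 1.0
    return g

logits = np.array([1.0, -0.5, 0.3, 2.0])   # hypothetical policy output in state s
probs = softmax(logits)
baseline = 7.3                              # any value depending on s only, not on a

# Exact expectation over a ~ pi(.|s) of grad log pi(a|s) * b(s).
expected = sum(probs[a] * grad_log_prob(logits, a) * baseline
               for a in range(len(probs)))
print(np.round(expected, 12))               # zeros, up to floating-point error
```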
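For instance, one choice of a baseline could be the on-policy value function \(V^{\pi}(s_t)\): the average return an agent gets if it starts in state \(s_t\) and then acts according to policy \(\pi\).
In practice, \(V^{\pi}\) cannot be computed exactly, so it has to be approximated, usually with a neural network. At every step, we subtract the estimated value from the reward-to-go and use only the difference as the weight for the log prob.
Empirically, this reduces the variance of the sample estimate of the policy gradient.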
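The sketch below shows the effect in a deliberately simple setting. It assumes a single state, a fixed 50/50 policy over two actions, and noisy rewards centred on 10 and 11. Because the setup is a toy, the value \(V(s) = 10.5\) is known exactly; in a real problem it would be an estimate. The per-sample gradient estimates keep the same mean with and without the baseline, but their variance collapses once the baseline is subtracted:

```python
import numpy as np

rng = np.random.default_rng(0)
probs = np.array([0.5, 0.5])          # hypothetical fixed policy in a single state
mean_reward = np.array([10.0, 11.0])  # action 1 is only slightly better
n = 10_000

actions = rng.choice(2, size=n, p=probs)
returns = mean_reward[actions] + rng.normal(0.0, 0.1, size=n)

# Per-sample gradient of log pi(a) w.r.t. the first logit: 1[a == 0] - pi(0).
score = (actions == 0).astype(float) - probs[0]

baseline = float(probs @ mean_reward)  # V(s) = 10.5, known only because this is a toy
grad_no_baseline = score * returns
grad_with_baseline = score * (returns - baseline)

print("mean:", grad_no_baseline.mean(), grad_with_baseline.mean())
print("variance:", grad_no_baseline.var(), grad_with_baseline.var())
```

Without the baseline, every sample is dominated by the large, shared offset of about 10 in the rewards. That offset pushes the log prob of whichever action was sampled up by a lot, and the expected gradient only emerges after many samples cancel out. With the baseline, each sample carries mostly the useful signal: action 1 is better than average, and action 0 is worse.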
Intuitively, we could read this as follows:
We don't really care about the reward associated with a given action. What we want to know is whether the chosen action leads to a better outcome than what we might expect to achieve on average from the current state.
A related idea is that of an advantage function.
Advantage functions come in different shapes and sizes, and you can read more about them in the following paper: High-Dimensional Continuous Control Using Generalized Advantage Estimation.
The interesting aspect of advantage functions is that if a given action is better than what we could expect from our policy, the modified reward term will be positive.
If the given action would lead to a worse outcome, the modified reward term is negative.
And so we either increase or decrease the probability of a given action based on how it compares to the performance of our policy on average!
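One simple advantage estimate is the reward-to-go minus a value estimate \(V(s_t)\). It is only one of the options discussed in the paper above, and this sketch is not Generalized Advantage Estimation itself. The rewards and the value estimates below are hypothetical numbers typed in by hand; in practice the values would come from a learned value function:

```python
import numpy as np

def reward_to_go(rewards):
    # Sum of rewards from each timestep to the end of the episode.
    return np.cumsum(np.asarray(rewards, dtype=float)[::-1])[::-1]

rewards = [0.0, 1.0, 0.0, 3.0]           # hypothetical per-step rewards
values = np.array([3.5, 4.5, 3.0, 2.0])  # hypothetical value estimates V(s_t)

advantages = reward_to_go(rewards) - values
# Positive: the action did better than expected, so its log prob is pushed up.
# Negative: it did worse than expected, so its log prob is pushed down.
# Zero: this sample gives the action no push either way.
print(advantages)
print(np.sign(advantages))
```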
Summary
When do you understand something?
Did Newton understand gravity? Or was it only with Einstein and general relativity that we finally learned what gravity is?
I used to believe that I had to understand every little bit about anything I wanted to use.
And that is a very insecure and limiting way to live your life.
I want math to make sense to me, but understanding the details without grasping the bigger picture is useless.
Plus, there is no shame in stopping at a level of understanding that suits your current situation and what you want to achieve.
Sometimes going down a rabbit hole is the most rewarding, most enjoyable thing you can do; sometimes it is just a giant waste of time.
I wrote this blog post so that you and I can decide where to stop, depending on what will help us pursue our current objective most effectively.