KL Divergence
Machine learning often involves approximating intractable probability distributions. One approach is to find a distribution that minimizes the KL Divergence from the target distribution. For example, the approximating distributions could be normal distributions with different means and variances.
When KL Divergence is introduced in the context of machine learning, one common point is that minimizing \( KL(P \mid\mid Q) \) selects a different approximating distribution than minimizing \( KL(Q \mid\mid P) \). This blog post explores that difference by telling an optimizer (TensorFlow) to minimize each of the two KL Divergences.
KL Divergence equation for discrete distributions
Wikipedia gives the KL Divergence for discrete distributions as
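\[ KL(P \mid\mid Q) = \sum_i P_i \log\left(\frac{P_i}{Q_i}\right) \]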
If \( P_i = 0 \), then the \( i^{th} \) term is 0. \( KL \) is only defined if \( P_i = 0 \) whenever \( Q_i = 0 \).
For example, we can have \( P \) be the distribution we’re trying to approximate with \( Q \). The KL Divergence will be large if \( Q_i \) is close to 0 where \( P_i \) is not close to 0. Where \( P_i \) is close to 0, \( Q_i \) won’t affect the KL Divergence as much.
An example target distribution and two example approximate distributions
Let’s plot a few examples!
For this first example, I’ll make \( P \) based on the distribution \( \text{Beta}(2, 5) \). This is interesting because \( P_i \) is 0 outside the interval from 0 to 1. I’ll use \( Q \)s that are based on a normal distribution, so \( Q_i \) is never 0. I highlight the area where \( P_i > 0 \).
Aside: Discrete vs Continuous
In order to make cool-looking graphs, I’m using discrete distributions that are based on continuous distributions, like the normal distribution. For example, below I start with 200 evenly-spaced numbers between -1 and 2. I compute the value of the PDF for those numbers. Then I normalize the vector so the 200 numbers add to 1 and it becomes a discrete distribution.
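A minimal sketch of that discretization, using scipy’s beta and norm PDFs on the 200-point grid (the specific variable names are my own):

```python
import numpy as np
from scipy import stats

# 200 evenly-spaced numbers between -1 and 2
x = np.linspace(-1, 2, 200)

# Evaluate each continuous PDF on the grid...
p_pdf = stats.beta(2, 5).pdf(x)        # target distribution, Beta(2, 5)
q_pdf = stats.norm(0.2, 0.15).pdf(x)   # one candidate approximation

# ...then normalize so the 200 values sum to 1, turning them into discrete distributions
p = p_pdf / p_pdf.sum()
q = q_pdf / q_pdf.sum()
```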
Computing KL Divergence
I can translate the formula to numpy, then compute the KL Divergence between the two approximating distributions and the target distribution.
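A sketch of that translation (the helper name kl_divergence is my own; p and q are the discretized distributions from the earlier sketch):

```python
import numpy as np

def kl_divergence(p, q):
    """KL(P || Q) for discrete distributions given as 1-D numpy arrays that sum to 1."""
    # Terms where P_i = 0 contribute nothing, so only sum where P_i > 0.
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))
```

Calling kl_divergence(p, q) for each candidate \( Q \) gives the comparison below.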
As expected, the KL Divergence is higher for the approximating distribution based on Norm(1, 0.2) than for the one based on Norm(0.2, 0.15).
Q = Norm(1.0, 0.20) KL(P || Q) = 6.491177
Q = Norm(0.2, 0.15) KL(P || Q) = 0.236206
Aside: Verifying the implementation
scipy’s entropy computes KL Divergence when called with two parameters. I can verify my implementation produces similar results.
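Something like this would do the check (reusing p, q, and kl_divergence from above):

```python
from scipy.stats import entropy

# entropy(p, q) returns the KL Divergence KL(P || Q) when given two arguments
print(entropy(p, q))
print(kl_divergence(p, q))  # should agree up to floating-point noise
```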
Aside: Interactive
Before I implement something that minimizes the divergence automatically, I can use ipywidgets to interactively try different distributions.
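A rough sketch of the interactive piece, assuming ipywidgets’ interact plus the grid x, target p, and kl_divergence from the earlier snippets (the slider ranges are just illustrative):

```python
import matplotlib.pyplot as plt
from ipywidgets import interact
from scipy import stats

@interact(mu=(-1.0, 2.0, 0.05), sigma=(0.05, 1.0, 0.05))
def plot_approximation(mu, sigma):
    # Discretize the candidate Norm(mu, sigma) on the same grid as P
    q = stats.norm(mu, sigma).pdf(x)
    q = q / q.sum()
    plt.plot(x, p, label='P (target)')
    plt.plot(x, q, label='Q (approximation)')
    plt.legend()
    plt.title('KL(P || Q) = {:.4f}'.format(kl_divergence(p, q)))
    plt.show()
```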
Multimodal Example
A key point about KL Divergence is that finding a \( Q \) that minimizes \( KL(Q \mid\mid P) \) is different from finding a \( Q \) that minimizes \( KL(P \mid\mid Q) \). One way to illustrate the difference is to look at a multimodal distribution.
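As a sketch, one way to build such a target is to discretize a mixture of two normals (the means, scales, and weights here are just illustrative choices):

```python
import numpy as np
from scipy import stats

x = np.linspace(-1, 2, 200)

# Bimodal target: mixture of two normals with modes near 0.1 and 1.2
p_pdf = 0.5 * stats.norm(0.1, 0.1).pdf(x) + 0.5 * stats.norm(1.2, 0.15).pdf(x)
p = p_pdf / p_pdf.sum()
```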
Minimizing KL Divergence
I can implement KL Divergence in TensorFlow and then use gradient descent to find an approximating distribution \( Q = Norm(\mu, \sigma^2) \) that minimizes the KL Divergence.
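Here’s a minimal sketch of what minimizing \( KL(P \mid\mid Q) \) might look like in TensorFlow 2 (the grid x and target p come from the earlier snippets; the starting values, learning rate, and choice of Adam are my own illustrative choices):

```python
import numpy as np
import tensorflow as tf

x_tf = tf.constant(x, dtype=tf.float32)
p_tf = tf.constant(p, dtype=tf.float32)
p_mask = tf.constant(p > 0)  # only terms with P_i > 0 contribute

# Parameters of Q = Norm(mu, sigma^2); sigma stays positive via exp
mu = tf.Variable(0.5)                                    # illustrative starting point
log_sigma = tf.Variable(np.log(0.5), dtype=tf.float32)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)

def kl_p_q():
    """KL(P || Q), with Q the discretized Norm(mu, sigma^2) on the grid."""
    sigma = tf.exp(log_sigma)
    q = tf.exp(-0.5 * ((x_tf - mu) / sigma) ** 2)
    q = q / tf.reduce_sum(q)  # normalize so Q sums to 1
    p_pos = tf.boolean_mask(p_tf, p_mask)
    q_pos = tf.boolean_mask(q, p_mask)
    return tf.reduce_sum(p_pos * tf.math.log(p_pos / q_pos))

for step in range(500):
    with tf.GradientTape() as tape:
        loss = kl_p_q()
    grads = tape.gradient(loss, [mu, log_sigma])
    optimizer.apply_gradients(zip(grads, [mu, log_sigma]))

print(mu.numpy(), tf.exp(log_sigma).numpy(), loss.numpy())
```

The \( KL(Q \mid\mid P) \) version swaps the roles of \( P \) and \( Q \) inside the sum, and the rest of the loop stays the same.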
It’s also neat to plot how the distribution shifts as it improves!
Comparison of KL(Q || P) to KL(P || Q)
Finally, I can compare the \( Q \) that minimizes \( KL(Q \mid\mid P) \) to the one that minimizes \( KL(P \mid\mid Q) \).
See Also
- Chapter 10 of Bishop’s PRML talks about Variational Inference and has examples using bimodal distributions over two dimensions.
- Another example of dealing with difficult distributions is using Gibbs Sampling!
- KL Divergence is used in variational methods such as Variational Inference, which is similar in spirit to Expectation-Maximization.