As part of reviewing the ML concepts I learned last year, I implemented sum-product message passing, or belief propagation, which we learned in our probabilistic modeling course.
Belief propagation (or sum-product message passing) is a method that can do inference on probabilistic graphical models. I’ll focus on the algorithm that can perform exact inference on tree-like factor graphs.
This post assumes familiarity with probabilistic graphical models (perhaps from the Coursera course) and that you’ve at least heard of belief propagation. I’ll freely use terms such as “factor graph” and “exact inference.”
Belief Propagation
Belief propagation, or sum-product message passing, is an algorithm for efficiently applying the sum and product rules of probability to compute different distributions. For example, if a discrete probability distribution \( p(h_1, v_1, h_2, v_2) \) can be factorized as

\[ p(h_1, v_1, h_2, v_2) = p(h_1)\,p(h_2 \mid h_1)\,p(v_1 \mid h_1)\,p(v_2 \mid h_2), \]
I could compute marginals, for example, \( p(v_1) \), by multiplying the terms and summing over the other variables.
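Written out for this factorization, that marginal is

\[ p(v_1) = \sum_{h_1} \sum_{h_2} \sum_{v_2} p(h_1)\, p(h_2 \mid h_1)\, p(v_1 \mid h_1)\, p(v_2 \mid h_2). \]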
With marginals, one can compute distributions such as \( p(v_1) \) and \( p(v_1, v_2) \), which means that one can also compute terms like \( p(v_2 \mid v_1) \). Belief propagation provides an efficient method for computing these marginals.
This version will only work on discrete distributions. I’ll code it with directed graphical models in mind, though it should also work with undirected models with a few changes.
Part 1: (Digression) Representing probability distributions as numpy arrays
Sum-product message passing involves representing, summing, and multiplying discrete distributions. I think it’s pretty fun to try to implement this with numpy arrays; I gained more intuition about probability distributions and numpy.
A discrete conditional distribution \( p(v_1 \mid h_1) \) can be represented as an array with two axes, such as
| | \( h_1 \) = a | \( h_1 \) = b | \( h_1 \) = c |
|---|---|---|---|
| \( v_1 \) = 0 | 0.4 | 0.8 | 0.9 |
| \( v_1 \) = 1 | 0.6 | 0.2 | 0.1 |
Using an axis for each variable can generalize to more variables. For example, the 5-variable \( p(h_5 \mid h_4, h_3, h_2, h_1) \) could be represented by an array with five axes.
It’s useful to label axes with variable names. I’ll do this in my favorite way, a little namedtuple! (It’s kind of like a janky version of the NamedTensor.)
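Here’s a minimal sketch: a LabeledArray namedtuple pairing a numpy array with a label for each of its axes, holding the \( p(v_1 \mid h_1) \) table from above.

```python
from collections import namedtuple

import numpy as np

# Pair a numpy array with a label for each of its axes.
LabeledArray = namedtuple('LabeledArray', ['array', 'axes_labels'])

# p(v1 | h1) from the table above: axis 0 is v1, axis 1 is h1.
p_v1_given_h1 = LabeledArray(
    np.array([[0.4, 0.8, 0.9],
              [0.6, 0.2, 0.1]]),
    ['v1', 'h1'],
)
```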
Checking that a numpy array is a valid discrete distribution
It’s easy to accidentally swap axes when creating numpy arrays representing distributions. I’ll also write code to verify they are valid distributions.
To check that a multidimensional array is a joint distribution, I can check that the entire array sums to one.

To check that a 2D array is a conditional distribution, once all of the right-hand-side variables have been assigned, such as \( p(v_1 \mid h_1 = a) \), the resulting vector should represent a distribution: it should have one entry per state of \( v_1 \) and sum to one. Computing this in numpy involves summing along the axis corresponding to the \( v_1 \) variable.

This generalizes to conditional distribution arrays with more dimensions: once all of the right-hand-side variables have been assigned, such as \( p(h_5 \mid h_4=a, h_3=b, h_2=a, h_1=a) \), the resulting vector should represent a distribution. It should have one entry per state of \( h_5 \) and sum to one.
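Here’s a sketch of those two checks, assuming the LabeledArray from above (the function names are just for illustration):

```python
def is_joint_distribution(labeled_array):
    # A joint distribution should sum to one over all of its axes.
    return np.isclose(labeled_array.array.sum(), 1.0)

def is_conditional_distribution(labeled_array, variable_name):
    # Summing out the left-hand-side variable should leave 1.0 for
    # every assignment of the conditioning (right-hand-side) variables.
    axis = labeled_array.axes_labels.index(variable_name)
    return np.all(np.isclose(labeled_array.array.sum(axis=axis), 1.0))

assert is_conditional_distribution(p_v1_given_h1, 'v1')
```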
Multiplying distributions
In belief propagation, I also need to compute the product of distributions, such as \( p(h_2 \mid h_1)p(h_1) \).
In this case, I’ll only need to multiply a multidimensional array by a 1D array and occasionally a scalar. The way I ended up implementing this was to align the axis of the 1D array with its corresponding axis from the other distribution. Then I tile the 1D array to be the size of \( p(h_2 \mid h_1) \). This gives me the joint distribution \( p(h_1, h_2) \).
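A sketch of this, using numpy broadcasting to do the tiling (the function name is illustrative; I’m reusing \( p(v_1 \mid h_1) \) from above with a made-up \( p(h_1) \)):

```python
def multiply_1d(labeled_array, vector, variable_name):
    # Reshape the 1D array so its axis lines up with the matching axis
    # of the other distribution; broadcasting then tiles it across the
    # remaining axes.
    axis = labeled_array.axes_labels.index(variable_name)
    shape = [1] * labeled_array.array.ndim
    shape[axis] = len(vector)
    return LabeledArray(labeled_array.array * vector.reshape(shape),
                        labeled_array.axes_labels)

# p(v1, h1) = p(v1 | h1) p(h1)
p_h1 = np.array([0.6, 0.3, 0.1])
p_v1_h1 = multiply_1d(p_v1_given_h1, p_h1, 'h1')
assert is_joint_distribution(p_v1_h1)
```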
Part 2: Factor Graphs
Factor graphs are used to represent a distribution for sum-product message passing.
One factor graph that represents \( p(h_1, h_2, v_1, v_2) \) is
Factors, such as \( p(h_1) \), are drawn as black squares and represent a function (such as a probability distribution). Variables, such as \( h_1 \), are drawn as white circles. Variables only neighbor factors, and factors only neighbor variables.
In code, a minimal sketch of these classes could look like this (the exact layout is my reconstruction):
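```python
class Node:
    def __init__(self, name):
        self.name = name
        self.neighbors = []

    def add_neighbor(self, node):
        assert self.is_valid_neighbor(node)
        self.neighbors.append(node)

class Variable(Node):
    def is_valid_neighbor(self, node):
        return isinstance(node, Factor)  # variables only neighbor factors

class Factor(Node):
    def __init__(self, name):
        super().__init__(name)
        self.data = None  # will hold this factor's LabeledArray

    def is_valid_neighbor(self, node):
        return isinstance(node, Variable)  # factors only neighbor variables
```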
There are two classes in the graph: Variable and Factor. Both classes have a string representing the name and a list of neighbors.
A Variable can only have Factors in its list of neighbors. A Factor can only have Variables.
To represent the probability distribution, Factors also have a field for data.
Part 3: Parsing distributions into graphs
Defining a graph can be a little verbose. I can hack together a parser for probability distributions that can interpret a string like p(h1)p(h2|h1)p(v1|h1)p(v2|h2) as a factor graph for me.

(This is pretty fragile and not user-friendly. For example, be sure to use the | character rather than the visually indistinguishable ∣ character!)
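Here’s a sketch of such a parser, building on the Variable and Factor classes above (the regex and function name are illustrative):

```python
import re

# Matches factors like "p(h1)" or "p(v1|h1)"; multiple right-hand-side
# variables would be separated by commas.
FACTOR_RE = re.compile(r'p\(([^|)]+)(?:\|([^)]+))?\)')

def parse_factor_graph(string):
    variables = {}
    factors = []
    for match in FACTOR_RE.finditer(string.replace(' ', '')):
        factor = Factor(match.group(0))
        lhs = match.group(1).split(',')
        rhs = match.group(2).split(',') if match.group(2) else []
        for name in lhs + rhs:
            variable = variables.setdefault(name, Variable(name))
            factor.add_neighbor(variable)
            variable.add_neighbor(factor)
        factors.append(factor)
    return variables, factors

variables, factors = parse_factor_graph('p(h1)p(h2|h1)p(v1|h1)p(v2|h2)')
```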
Part 4: Associating data with factors

Before I can run the algorithm, I need to associate LabeledArrays with each Factor. At this point, I’ll create a class to hold onto the Variables and Factors.
While I’m here, I can do a few checks to make sure the provided data matches the graph.
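A sketch of that class, with a couple of those checks (the method layout is my own; PGM is the class named later in this post):

```python
class PGM:
    def __init__(self, variables, factors):
        self._variables = variables  # dict mapping name -> Variable
        self._factors = factors

    @classmethod
    def from_string(cls, string):
        return cls(*parse_factor_graph(string))

    def variable_from_name(self, name):
        return self._variables[name]

    def set_data(self, data):
        # data maps factor names like 'p(v1|h1)' to LabeledArrays.
        variable_num_states = {}
        for factor in self._factors:
            labeled_array = data[factor.name]
            # The array's axis labels should match the factor's neighbors.
            assert set(labeled_array.axes_labels) == {v.name for v in factor.neighbors}
            # Each variable should have the same number of states
            # everywhere it appears.
            for axis, name in enumerate(labeled_array.axes_labels):
                num_states = labeled_array.array.shape[axis]
                assert variable_num_states.setdefault(name, num_states) == num_states
            factor.data = labeled_array
```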
I can now try to add distributions to a graph.
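For example, with made-up numbers for each distribution (reusing \( p(v_1 \mid h_1) \) and \( p(h_1) \) from Part 1):

```python
pgm = PGM.from_string('p(h1)p(h2|h1)p(v1|h1)p(v2|h2)')

pgm.set_data({
    'p(h1)': LabeledArray(np.array([0.6, 0.3, 0.1]), ['h1']),
    'p(h2|h1)': LabeledArray(np.array([[0.7, 0.2, 0.1],
                                       [0.2, 0.6, 0.2],
                                       [0.1, 0.2, 0.7]]), ['h2', 'h1']),
    'p(v1|h1)': p_v1_given_h1,
    'p(v2|h2)': LabeledArray(np.array([[0.4, 0.8, 0.9],
                                       [0.6, 0.2, 0.1]]), ['v2', 'h2']),
})
```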
Part 5: Belief Propagation
We made it! Now we can implement sum-product message passing.
Sum-product message passing will compute values (“messages”) for every edge in the factor graph.
The algorithm will compute a message from the Factor \( f \) to the Variable \( x \), notated as \( \mu_{f \to x}(x) \). It will also compute the message from Variable \( x \) to the Factor \( f \), \( \mu_{x \to f}(x) \). As is common in graph algorithms, these are defined recursively.
(I’m using the equations as given in Barber p84.)
Variable-to-Factor Message
The variable-to-factor message is given by

\[ \mu_{x \to f}(x) = \prod_{g \in \mathrm{ne}(x) \setminus \{f\}} \mu_{g \to x}(x), \]

where \( \mathrm{ne}(x) \) are the neighbors of \( x \), so the product runs over every neighboring factor of \( x \) except the recipient \( f \).
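A sketch in code, assuming the classes above (factor_to_variable_message is defined in the next section; the two functions are mutually recursive):

```python
def variable_to_factor_message(variable, factor):
    # Multiply together the incoming messages from every neighboring
    # factor except the recipient f.
    incoming = [factor_to_variable_message(g, variable)
                for g in variable.neighbors
                if g is not factor]
    # A leaf variable has no other neighbors and sends the message 1.
    return np.prod(incoming, axis=0) if incoming else 1.0
```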
Factor-to-Variable Message
The factor-to-variable message is given by

\[ \mu_{f \to x}(x) = \sum_{\chi_f \setminus x} \phi_f(\chi_f) \prod_{y \in \mathrm{ne}(f) \setminus \{x\}} \mu_{y \to f}(y). \]

In the case of probabilities, \( \phi_f(\chi_f) \) is the probability distribution associated with the factor, and \( \sum_{\chi_f \setminus x} \) sums over all variables except \( x \).
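A sketch in code, reusing the axis-aligning broadcast trick from Part 1 (np.broadcast_to also handles the scalar 1.0 sent by leaf variables):

```python
def factor_to_variable_message(factor, variable):
    message = factor.data.array
    axes_labels = list(factor.data.axes_labels)
    for neighbor in factor.neighbors:
        if neighbor is variable:
            continue
        # Align each incoming 1D message with its axis and multiply it in.
        incoming = variable_to_factor_message(neighbor, factor)
        axis = axes_labels.index(neighbor.name)
        shape = [1] * message.ndim
        shape[axis] = message.shape[axis]
        message = message * np.broadcast_to(
            incoming, (message.shape[axis],)).reshape(shape)
    # Sum out every variable except x.
    x_axis = axes_labels.index(variable.name)
    other_axes = tuple(i for i in range(message.ndim) if i != x_axis)
    return message.sum(axis=other_axes)
```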
Marginal
The marginal of a variable \( x \) is given by

\[ p(x) \propto \prod_{f \in \mathrm{ne}(x)} \mu_{f \to x}(x). \]
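In code, using the two message functions above and normalizing at the end:

```python
def marginal(variable):
    # Product of all incoming factor-to-variable messages, normalized
    # so the result sums to one.
    unnormalized = np.prod([factor_to_variable_message(f, variable)
                            for f in variable.neighbors], axis=0)
    return unnormalized / unnormalized.sum()
```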
Adding to PGM
A source of message passing’s efficiency is that messages from one computation can be reused by other computations. I’ll create an object to store Messages.
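Here’s a sketch of a Messages class that wraps the two message computations above with a cache keyed by the edge (the layout is my own; the recursive calls go through self so sub-messages land in the cache too):

```python
class Messages:
    """Computes messages, caching each one so every edge is computed once."""

    def __init__(self):
        self.cache = {}

    def variable_to_factor_message(self, variable, factor):
        key = (variable.name, factor.name)
        if key not in self.cache:
            incoming = [self.factor_to_variable_message(g, variable)
                        for g in variable.neighbors if g is not factor]
            # A leaf variable sends the constant message 1.
            self.cache[key] = np.prod(incoming, axis=0) if incoming else 1.0
        return self.cache[key]

    def factor_to_variable_message(self, factor, variable):
        key = (factor.name, variable.name)
        if key not in self.cache:
            # Same computation as before, with recursion routed through
            # the cache.
            message = factor.data.array
            axes_labels = list(factor.data.axes_labels)
            for neighbor in factor.neighbors:
                if neighbor is variable:
                    continue
                incoming = self.variable_to_factor_message(neighbor, factor)
                axis = axes_labels.index(neighbor.name)
                shape = [1] * message.ndim
                shape[axis] = message.shape[axis]
                message = message * np.broadcast_to(
                    incoming, (message.shape[axis],)).reshape(shape)
            x_axis = axes_labels.index(variable.name)
            other_axes = tuple(i for i in range(message.ndim) if i != x_axis)
            self.cache[key] = message.sum(axis=other_axes)
        return self.cache[key]

    def marginal(self, variable):
        # Product of all incoming messages, normalized to sum to one.
        unnormalized = np.prod([self.factor_to_variable_message(f, variable)
                                for f in variable.neighbors], axis=0)
        return unnormalized / unnormalized.sum()

messages = Messages()
print(messages.marginal(pgm.variable_from_name('v1')))
# -> [0.57 0.43] with the example numbers above
```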