# Implementing Sum-Product Message Passing

As part of reviewing the ML concepts I learned last year, I implemented the *sum-product message passing* algorithm we learned in our probabilistic modeling course.

Sum-product message passing (or belief propagation) is a method for performing inference on probabilistic graphical models. I’ll focus on the version that performs exact inference on tree-structured factor graphs.

This post assumes familiarity with probabilistic graphical models (perhaps through the Coursera course) and that you may have heard of sum-product message passing. I’ll freely use terms such as “factor graph” and “exact inference.”

## Sum-Product Message Passing

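Sum-product message passing is an algorithm for efficiently applying the sum and product rules of probability to compute different distributions. For example, if a discrete probability distribution \( p(h_1, v_1, h_2, v_2) \) can be factorized as

\[ p(h_1, h_2, v_1, v_2) = p(h_1)\,p(h_2 \mid h_1)\,p(v_1 \mid h_1)\,p(v_2 \mid h_2), \]

I could compute marginals, for example \( p(v_1) \), by multiplying the terms and summing over the other variables:

\[ p(v_1) = \sum_{h_1} \sum_{h_2} \sum_{v_2} p(h_1)\,p(h_2 \mid h_1)\,p(v_1 \mid h_1)\,p(v_2 \mid h_2). \]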

With marginals, one can compute distributions such as \( p(v_1) \) and \( p(v_1, v_2) \), which means that one can also compute terms like \( p(v_2 \mid v_1) \). Sum-product message passing provides an efficient method for computing these marginals.
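To make the brute-force approach concrete, here is a small numpy sketch. The probability tables are illustrative values I’m assuming (every variable has two states). `np.einsum` multiplies the four factors into the full joint, and summing out axes gives \( p(v_1) \), \( p(v_1, v_2) \), and from those \( p(v_2 \mid v_1) \).

```python
import numpy as np

# Illustrative (assumed) tables; every h and v variable has two states.
p_h1 = np.array([0.2, 0.8])                   # p(h1), axis: h1
p_h2_given_h1 = np.array([[0.5, 0.2],
                          [0.5, 0.8]])        # p(h2 | h1), axes: [h2, h1]
p_v1_given_h1 = np.array([[0.6, 0.1],
                          [0.4, 0.9]])        # p(v1 | h1), axes: [v1, h1]
p_v2_given_h2 = np.array([[0.6, 0.1],
                          [0.4, 0.9]])        # p(v2 | h2), axes: [v2, h2]

# Brute force: build the full joint p(h1, h2, v1, v2), indexed [i, j, k, l] ...
joint = np.einsum('i,ji,ki,lj->ijkl',
                  p_h1, p_h2_given_h1, p_v1_given_h1, p_v2_given_h2)

# ... then sum out the variables we don't want.
p_v1 = joint.sum(axis=(0, 1, 3))               # p(v1)
p_v1_v2 = joint.sum(axis=(0, 1))               # p(v1, v2), axes: [v1, v2]
p_v2_given_v1 = p_v1_v2 / p_v1[:, np.newaxis]  # each row (fixed v1) sums to one
```

The joint here has \( 2^4 \) entries; with more variables it grows exponentially, which is the cost that sum-product message passing avoids on tree-structured graphs.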

This version will only work on discrete distributions. I’ll code it with directed graphical models in mind, though it should also work with undirected models with few changes.

## Part 1: (Digression) Representing probability distributions as numpy arrays

Sum-product message passing involves representing, summing, and multiplying discrete distributions. I think it’s pretty fun to implement this with numpy arrays; doing so gave me more intuition about both probability distributions and numpy.

A discrete conditional distribution \( p(v_1 \mid h_1) \) can be represented as an array with two axes, such as

| | \( h_1 \) = a | \( h_1 \) = b | \( h_1 \) = c |
|---|---|---|---|
| \( v_1 \) = 0 | 0.4 | 0.8 | 0.9 |
| \( v_1 \) = 1 | 0.6 | 0.2 | 0.1 |

Using an axis for each variable can generalize to more variables. For example, the 5-variable \( p(h_5 \mid h_4, h_3, h_2, h_1) \) could be represented by an array with five axes.

It’s useful to label axes with variable names. I’ll do this in my favorite way, a little `namedtuple`! (It’s kind of like a janky version of the NamedTensor.)
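A minimal sketch of that idea, assuming the field names `array` and `axes_labels` and a hypothetical helper `name_to_axis_mapping`. It encodes the \( p(v_1 \mid h_1) \) table above with \( v_1 \) on axis 0 and \( h_1 \) on axis 1.

```python
from collections import namedtuple

import numpy as np

# Assumed field names: the raw array plus one variable name per axis.
LabeledArray = namedtuple('LabeledArray', ['array', 'axes_labels'])


def name_to_axis_mapping(labeled_array):
    """Map each variable name to the axis that indexes it."""
    return {name: axis for axis, name in enumerate(labeled_array.axes_labels)}


# p(v1 | h1) from the table: axis 0 is v1 (states 0, 1), axis 1 is h1 (states a, b, c).
p_v1_given_h1 = LabeledArray(
    array=np.array([[0.4, 0.8, 0.9],
                    [0.6, 0.2, 0.1]]),
    axes_labels=['v1', 'h1'],
)
```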

### Checking that a numpy array is a valid discrete distribution

It’s easy to accidentally swap axes when creating numpy arrays representing distributions. I’ll also write code to verify they are valid distributions.

To check that a multidimensional array is a *joint* distribution, the entire array should sum to one.

To check that a 2D array is a *conditional* distribution, note that once all of the right-hand-side variables have been assigned, as in \( p(v_1 \mid h_1 = a) \), the resulting vector is a distribution. The vector should have the length of the number of states of \( v_1 \) and should sum to one. Computing this in numpy involves summing along the axis corresponding to the \( v_1 \) variable.

The same idea generalizes to multi-dimensional conditional distribution arrays: once all of the right-hand-side variables have been assigned, as in \( p(h_5 \mid h_4=a, h_3=b, h_2=a, h_1=a) \), the resulting vector is a distribution. The vector should have a length equal to the number of states of \( h_5 \) and should sum to one, so the check sums along the \( h_5 \) axis.
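These checks can be sketched as two small functions. `is_joint_distribution` and `is_conditional_distribution` are hypothetical names, and the assumed `LabeledArray` definition is repeated so the snippet runs on its own. The conditional check sums along the left-hand-side variable’s axis and expects a one for every assignment of the remaining variables.

```python
from collections import namedtuple

import numpy as np

# Assumed representation: the raw array plus one variable name per axis.
LabeledArray = namedtuple('LabeledArray', ['array', 'axes_labels'])


def is_joint_distribution(labeled_array):
    """A joint distribution: the entire array sums to one."""
    return bool(np.isclose(labeled_array.array.sum(), 1.0))


def is_conditional_distribution(labeled_array, variable_name):
    """p(variable | rest): summing out `variable_name` leaves a one for every assignment of the rest."""
    axis = list(labeled_array.axes_labels).index(variable_name)
    return bool(np.allclose(labeled_array.array.sum(axis=axis), 1.0))


p_v1_given_h1 = LabeledArray(np.array([[0.4, 0.8, 0.9],
                                       [0.6, 0.2, 0.1]]), ['v1', 'h1'])

# A random 5-variable p(h5 | h4, h3, h2, h1), normalized along the h5 axis.
rng = np.random.default_rng(0)
raw = rng.random((3, 2, 2, 2, 2))
p_h5_given_rest = LabeledArray(raw / raw.sum(axis=0, keepdims=True),
                               ['h5', 'h4', 'h3', 'h2', 'h1'])
```

Checking `p_v1_given_h1` against `'h1'` instead of `'v1'` fails, which is exactly the swapped-axes mistake these checks are meant to catch.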

### Multiplying distributions

In sum-product message passing, I also need to compute the product of distributions, such as \( p(h_2 \mid h_1)p(h_1) \).

In this case, I’ll only need to multiply a multidimensional array by a 1D array and occasionally a scalar. The way I ended up implementing this was to align the axis of the 1D array with its corresponding axis from the other distribution. Then I tile the 1D array to be the size of \( p(h_2 \mid h_1) \). This gives me the joint distribution \( p(h_1, h_2) \).
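A sketch of the tiling step, assuming a hypothetical helper `tile_to_shape_along_axis` and illustrative values for \( p(h_2 \mid h_1) \) (axes `[h2, h1]`) and \( p(h_1) \). The 1D array is laid out along the matching axis and repeated along every other axis; a scalar is simply broadcast to the full shape.

```python
import numpy as np


def tile_to_shape_along_axis(arr, target_shape, target_axis):
    """Repeat a 1D array (or scalar) to target_shape, with its values along target_axis."""
    arr = np.asarray(arr, dtype=float)
    if arr.ndim == 0:
        return np.full(target_shape, float(arr))
    assert arr.ndim == 1 and arr.shape[0] == target_shape[target_axis]
    # Reshape to (1, ..., n, ..., 1) so the values sit on target_axis ...
    view_shape = [1] * len(target_shape)
    view_shape[target_axis] = arr.shape[0]
    # ... then repeat that view along every other axis.
    reps = [1 if axis == target_axis else size for axis, size in enumerate(target_shape)]
    return np.tile(arr.reshape(view_shape), reps)


# Illustrative (assumed) values; p(h2 | h1) has axes [h2, h1].
p_h2_given_h1 = np.array([[0.5, 0.2],
                          [0.5, 0.8]])
p_h1 = np.array([0.2, 0.8])

tiled_p_h1 = tile_to_shape_along_axis(p_h1, p_h2_given_h1.shape, target_axis=1)
p_h1_h2 = p_h2_given_h1 * tiled_p_h1  # joint p(h1, h2), still with axes [h2, h1]
```

NumPy broadcasting would give the same product without materializing the tiled copy; tiling just makes the axis alignment explicit.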

## Part 2: Factor Graphs

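Factor graphs are used to represent a distribution for sum-product message passing. One factor graph that represents \( p(h_1, h_2, v_1, v_2) = p(h_1)p(h_2 \mid h_1)p(v_1 \mid h_1)p(v_2 \mid h_2) \) has one factor per term: \( p(h_1) \) connects to \( h_1 \); \( p(h_2 \mid h_1) \) connects to \( h_2 \) and \( h_1 \); \( p(v_1 \mid h_1) \) connects to \( v_1 \) and \( h_1 \); and \( p(v_2 \mid h_2) \) connects to \( v_2 \) and \( h_2 \).

Factors, such as \( p(h_1) \), are drawn as black squares and stand for a function of their neighboring variables, such as a probability distribution. Variables, such as \( h_1 \), are drawn as white circles. Variables only neighbor factors, and factors only neighbor variables.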

In code,

- There are two classes in the graph: Variable and Factor. Both classes have a string representing the name and a list of neighbors.
- A Variable can only have Factors in its list of neighbors. A Factor can only have Variables.
- To represent the probability distribution, Factors also have a field for data.
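Those three bullet points translate into a pair of small classes. This is a sketch: the shared `Node` base class, the `add_neighbor` method, and raising `TypeError` on an invalid neighbor are my assumptions. The `__repr__` is chosen to print in the `Factor(p(h1), [h1])` style used in this post.

```python
class Node:
    """Shared base: a name plus a list of neighboring nodes."""

    def __init__(self, name):
        self.name = name
        self.neighbors = []

    def __repr__(self):
        neighbor_names = ', '.join(n.name for n in self.neighbors)
        return '{}({}, [{}])'.format(type(self).__name__, self.name, neighbor_names)

    def is_valid_neighbor(self, node):
        raise NotImplementedError

    def add_neighbor(self, node):
        # Enforce the bipartite structure: variables only touch factors and vice versa.
        if not self.is_valid_neighbor(node):
            raise TypeError('{} cannot neighbor {}'.format(self, node))
        self.neighbors.append(node)


class Variable(Node):
    def is_valid_neighbor(self, node):
        return isinstance(node, Factor)


class Factor(Node):
    def __init__(self, name):
        super().__init__(name)
        self.data = None  # later holds the factor's distribution

    def is_valid_neighbor(self, node):
        return isinstance(node, Variable)


h1 = Variable('h1')
p_h1 = Factor('p(h1)')
p_h1.add_neighbor(h1)
h1.add_neighbor(p_h1)
```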

## Part 3: Parsing distributions into graphs

Defining a graph can be a little verbose. I can hack together a parser for probability distributions that interprets a string like `p(h1)p(h2|h1)p(v1|h1)p(v2|h2)` as a factor graph for me.

(This is pretty fragile and not user-friendly. For example, be sure to use the `|` character rather than the nearly indistinguishable `∣` character!)
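Here is a minimal regex-based sketch of such a parser. It assumes each term looks like `p(lhs|rhs1,rhs2)` and, to stay short, returns factor and variable names rather than `Factor` and `Variable` objects; `parse_model` is a hypothetical name.

```python
import re


def parse_model(model):
    """Parse a string like 'p(h1)p(h2|h1)' into neighbor lists.

    Returns (factors, variables):
      factors:   list of (factor name, [variable names]), left-hand-side variable first
      variables: dict of variable name -> [neighboring factor names]
    """
    factors = []
    variables = {}
    for term in re.findall(r'p\(([^)]*)\)', model):
        lhs, _, rhs = term.partition('|')  # only the ASCII '|' splits
        var_names = [v.strip() for v in lhs.split(',')]
        if rhs:
            var_names += [v.strip() for v in rhs.split(',')]
        factor_name = 'p({})'.format(term)
        factors.append((factor_name, var_names))
        for v in var_names:
            variables.setdefault(v, []).append(factor_name)
    return factors, variables


factors, variables = parse_model('p(h1)p(h2|h1)p(v1|h1)p(v2|h2)')
```

On `p(h1)p(h2|h1)p(v1|h1)p(v2|h2)` this sketch yields the same neighbor structure as the output below. It also shows the fragility: with `∣` instead of `|`, `h2∣h1` is read as a single variable name.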

```
([Factor(p(h1), [h1]),
Factor(p(h2|h1), [h2, h1]),
Factor(p(v1|h1), [v1, h1]),
Factor(p(v2|h2), [v2, h2])],
{'h1': Variable(h1, [p(h1), p(h2|h1), p(v1|h1)]),
'h2': Variable(h2, [p(h2|h1), p(v2|h2)]),
'v1': Variable(v1, [p(v1|h1)]),
'v2': Variable(v2, [p(v2|h2)])})
```

## Part 4: Adding distributions to the graph

Before I can run the algorithm, I need to associate LabeledArrays with each Factor. At this point, I’ll create a class to hold onto the Variables and Factors.

While I’m here, I can do a few checks to make sure the provided data matches the graph.
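A sketch of such a container and its checks. It assumes the parser’s ordering (left-hand-side variable first) and the `LabeledArray` shape used earlier; the `PGM` constructor argument, `set_data`, and the exact set of checks are my assumptions. It verifies that the data covers exactly the graph’s factors, that axis labels match each factor’s neighbors in order, that each variable has a consistent number of states, and that each array is a conditional distribution over its first axis.

```python
from collections import namedtuple

import numpy as np

# Assumed representation: the raw array plus one variable name per axis.
LabeledArray = namedtuple('LabeledArray', ['array', 'axes_labels'])


class PGM:
    """Holds each factor's neighbor variables and, once set, its data."""

    def __init__(self, factor_variables):
        # factor_variables: factor name -> variable names, left-hand-side variable first
        self.factor_variables = factor_variables
        self.data = {}

    def set_data(self, data):
        if set(data) != set(self.factor_variables):
            raise ValueError('data must cover exactly the factors in the graph')
        num_states = {}
        for name, labeled in data.items():
            expected = self.factor_variables[name]
            # Axis labels must match the factor's neighbors, in the same order.
            if list(labeled.axes_labels) != expected:
                raise ValueError('{}: axes {} != {}'.format(name, labeled.axes_labels, expected))
            # A variable must have the same number of states in every factor.
            for label, size in zip(labeled.axes_labels, labeled.array.shape):
                if num_states.setdefault(label, size) != size:
                    raise ValueError('{}: inconsistent number of states'.format(label))
            # Summing out the first (left-hand-side) axis must leave all ones.
            if not np.allclose(labeled.array.sum(axis=0), 1.0):
                raise ValueError('{} is not a conditional distribution'.format(name))
        self.data = dict(data)


pgm = PGM({'p(h1)': ['h1'], 'p(h2|h1)': ['h2', 'h1']})
pgm.set_data({
    'p(h1)': LabeledArray(np.array([0.2, 0.8]), ['h1']),
    'p(h2|h1)': LabeledArray(np.array([[0.5, 0.2], [0.5, 0.8]]), ['h2', 'h1']),
})
```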

I can now try to add distributions to a graph.

## Part 5: Belief Propagation

We made it! Now we can implement sum-product message passing.

Sum-product message passing will compute values (“messages”) for every edge in the factor graph.

The algorithm will compute a message from the Factor \( f \) to the Variable \( x \), notated as \( \mu_{f \to x}(x) \). It will also compute the message from the Variable \( x \) to the Factor \( f \), \( \mu_{x \to f}(x) \). As is common in graph algorithms, these are defined recursively.

(I’m using the equations as given in Barber p84.)

### Variable-to-Factor Message

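The variable-to-factor message is given by:

\[ \mu_{x \to f}(x) = \prod_{g \in \{ne(x) \setminus f\}} \mu_{g \to x}(x) \]

where \( ne(x) \) are the neighbors of \( x \). An empty product is 1, so a variable whose only neighbor is \( f \) sends the constant message 1.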
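The variable-to-factor update is just a product of incoming messages. A minimal sketch, with a hypothetical function name; messages are numpy arrays over \( x \)’s states, and the empty product is the scalar `1.0`:

```python
import numpy as np


def variable_to_factor_message(incoming_from_other_factors):
    """mu_{x->f}(x): product of mu_{g->x}(x) over the neighbors g of x other than f.

    A leaf variable (no other neighbors) sends the constant message 1.0.
    """
    message = 1.0
    for msg in incoming_from_other_factors:
        message = message * np.asarray(msg, dtype=float)
    return message
```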

### Factor-to-Variable Message

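The factor-to-variable message is given by

\[ \mu_{f \to x}(x) = \sum_{\chi_f \setminus x} \phi_f(\chi_f) \prod_{y \in \{ne(f) \setminus x\}} \mu_{y \to f}(y) \]

In the case of probabilities, \( \phi_f(\chi_f) \) is the probability distribution associated with the factor, and \( \sum_{\chi_f \setminus x} \) sums over all variables except \( x \).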
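A sketch of the factor-to-variable update for a factor stored as a numpy array with one labeled axis per neighbor. It multiplies \( \phi_f \) by each incoming message aligned to that message’s axis (using broadcasting here rather than explicit tiling), then sums out every axis except \( x \)’s. The function name and the `incoming` dict format are assumptions, and the table in the usage line is illustrative.

```python
import numpy as np


def factor_to_variable_message(factor_array, axes_labels, target, incoming):
    """mu_{f->x}(x): multiply phi_f by each mu_{y->f}(y) with y != x, then sum out all axes but x.

    incoming: dict mapping every neighbor variable except `target` to its message,
    either a 1D array over that variable's states or the scalar 1.0 sent by a leaf.
    """
    product = np.asarray(factor_array, dtype=float)
    for axis, name in enumerate(axes_labels):
        if name == target:
            continue
        msg = np.asarray(incoming[name], dtype=float)
        if msg.ndim == 1:
            # Shape the message as (1, ..., n, ..., 1) so it lines up with this axis.
            shape = [1] * product.ndim
            shape[axis] = msg.shape[0]
            msg = msg.reshape(shape)
        product = product * msg
    target_axis = list(axes_labels).index(target)
    other_axes = tuple(a for a in range(product.ndim) if a != target_axis)
    return product.sum(axis=other_axes)


# Illustrative p(h2 | h1) with axes [h2, h1], receiving mu_{h1->f} = p(h1) = [0.2, 0.8].
msg_to_h2 = factor_to_variable_message(
    np.array([[0.5, 0.2], [0.5, 0.8]]), ['h2', 'h1'], 'h2', {'h1': np.array([0.2, 0.8])})
```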

### Marginal

The marginal of a variable \( x \) is given by

\[ p(x) \propto \prod_{f \in ne(x)} \mu_{f \to x}(x) \]
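A minimal sketch of the marginal, with a hypothetical function name: take the elementwise product of every incoming factor-to-variable message, then normalize. The division matters whenever that product does not already sum to one.

```python
import numpy as np


def marginal(incoming_factor_messages):
    """p(x) proportional to the product of mu_{f->x}(x) over all f in ne(x), normalized.

    incoming_factor_messages: one 1D array over x's states per neighboring factor.
    """
    unnormalized = np.prod(np.vstack(incoming_factor_messages), axis=0)
    return unnormalized / unnormalized.sum()
```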

## Adding to PGM

A source of message passing’s efficiency is that messages from one computation can be reused by other computations. I’ll create an object to store `Messages`.
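A sketch of how that memoization could look, written as a standalone class so it runs on its own. Assumptions: factors are passed as a dict from factor name to `(array, axis_labels)`, neighbor lists are derived from those labels, and the method names are hypothetical. Each message is computed once, stored under a `(sender, receiver)` key, and reused; the recursion only terminates on tree-structured graphs. With the illustrative tables in the usage lines, `marginal('v2')` matches (up to floating-point rounding) the marginal and the cached messages printed below, and a later `marginal('h1')` gives `[0.2, 0.8]` while reusing the cached messages.

```python
import numpy as np


class Messages:
    """Memoized sum-product messages on a tree-structured factor graph."""

    def __init__(self, factors):
        # factors: factor name -> (array, axis labels); axis i of the array is variable labels[i]
        self.factors = factors
        self.var_neighbors = {}
        for f, (_, labels) in factors.items():
            for v in labels:
                self.var_neighbors.setdefault(v, []).append(f)
        self.cache = {}  # (sender, receiver) -> message

    def variable_to_factor(self, x, f):
        if (x, f) not in self.cache:
            msg = 1.0  # empty product: a leaf variable sends 1
            for g in self.var_neighbors[x]:
                if g != f:
                    msg = msg * self.factor_to_variable(g, x)
            self.cache[(x, f)] = msg
        return self.cache[(x, f)]

    def factor_to_variable(self, f, x):
        if (f, x) not in self.cache:
            array, labels = self.factors[f]
            product = np.asarray(array, dtype=float)
            for axis, y in enumerate(labels):
                if y == x:
                    continue
                msg = np.asarray(self.variable_to_factor(y, f), dtype=float)
                if msg.ndim == 1:  # align the message with axis `axis`
                    shape = [1] * product.ndim
                    shape[axis] = msg.shape[0]
                    msg = msg.reshape(shape)
                product = product * msg
            keep = list(labels).index(x)
            self.cache[(f, x)] = product.sum(
                axis=tuple(a for a in range(product.ndim) if a != keep))
        return self.cache[(f, x)]

    def marginal(self, x):
        unnormalized = np.prod(
            [self.factor_to_variable(f, x) for f in self.var_neighbors[x]], axis=0)
        return unnormalized / unnormalized.sum()


# Illustrative (assumed) tables for p(h1)p(h2|h1)p(v1|h1)p(v2|h2).
factors = {
    'p(h1)': (np.array([0.2, 0.8]), ['h1']),
    'p(h2|h1)': (np.array([[0.5, 0.2], [0.5, 0.8]]), ['h2', 'h1']),
    'p(v1|h1)': (np.array([[0.6, 0.1], [0.4, 0.9]]), ['v1', 'h1']),
    'p(v2|h2)': (np.array([[0.6, 0.1], [0.4, 0.9]]), ['v2', 'h2']),
}
messages = Messages(factors)
p_v2 = messages.marginal('v2')
```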

```
array([0.23, 0.77])
```

```
{('p(h1)', 'h1'): array([0.2, 0.8]),
('v1', 'p(v1|h1)'): 1.0,
('p(v1|h1)', 'h1'): array([1., 1.]),
('h1', 'p(h2|h1)'): array([0.2, 0.8]),
('p(h2|h1)', 'h2'): array([0.26, 0.74]),
('h2', 'p(v2|h2)'): array([0.26, 0.74]),
('p(v2|h2)', 'v2'): array([0.23, 0.77])}
```

```
array([0.2, 0.8])
```

#### Example from book

Example 5.1 on p79 of Barber has a numerical example. I can make sure I get the same values (`[0.5746, 0.318 , 0.1074]`).