What is an intuitive explanation of the Expectation Maximization technique? [closed]

Note: the code behind this answer can be found here.


Suppose we have some data sampled from two different groups, red and blue:

[Image: data points from the red and blue groups, plotted along a number line in their group colours]

Here, we can see which data point belongs to the red or blue group. This makes it easy to find the parameters that characterise each group. For example, the mean of the red group is around 3, the mean of the blue group is around 7 (and we could find the exact means if we wanted).

This is, generally speaking, known as maximum likelihood estimation. Given some data, we compute the value of a parameter (or parameters) that best explains that data.

Now imagine that we cannot see which value was sampled from which group. Everything looks purple to us:

[Image: the same data points, all coloured purple]

Here we have the knowledge that there are two groups of values, but we don't know which group any particular value belongs to.

Can we still estimate the means for the red group and blue group that best fit this data?

Yes, often we can! Expectation Maximisation gives us a way to do it. The very general idea behind the algorithm is this:

  1. Start with an initial estimate of what each parameter might be.
  2. Compute the likelihood of each data point under the current parameter estimates.
  3. Calculate a weight for each data point, indicating whether it is more red or more blue, based on those likelihoods. Combine the weights with the data (expectation).
  4. Compute a better estimate for the parameters using the weight-adjusted data (maximisation).
  5. Repeat steps 2 to 4 until the parameter estimate converges (the process stops producing a different estimate).

These steps need some further explanation, so I'll walk through the problem described above.

Example: estimating mean and standard deviation

I'll use Python in this example, but the code should be fairly easy to understand if you're not familiar with this language.

Suppose we have two groups, red and blue, with the values distributed as in the image above. Specifically, each group contains 20 values drawn from a normal distribution with the following parameters:

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

Here is an image of these red and blue groups again (to save you from having to scroll up):

[Image: the red and blue samples along a number line]

When we can see the colour of each point (i.e. which group it belongs to), it's very easy to estimate the mean and standard deviation for each group. We just pass the red and blue values to the built-in functions in NumPy. For example:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

But what if we can't see the colours of the points? That is, instead of red or blue, every point has been coloured purple.

To try and recover the mean and standard deviation parameters for the red and blue groups, we can use Expectation Maximisation.

Our first step (step 1 above) is to guess at the parameter values for each group's mean and standard deviation. We don't have to guess intelligently; we can pick any numbers we like:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

These parameter estimates produce bell curves that look like this:

[Image: bell curves for the initial red and blue guesses, with their means marked by vertical dotted lines]

These are bad estimates. Both means (the vertical dotted lines) look far off any kind of "middle" for sensible groups of points, for instance. We want to improve these estimates.

The next step (step 2) is to compute the likelihood of each data point appearing under the current parameter guesses:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

Here, we have simply put each data point into the probability density function for a normal distribution using our current guesses at the mean and standard deviation for red and blue. This tells us, for example, that with our current guesses the data point at 1.761 is much more likely to be red (0.189) than blue (0.00003).
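
To see where those two numbers come from, we can evaluate the two density functions at that single point ourselves (a small sanity check added here, not part of the original walkthrough):

# check the likelihoods quoted above for the point at 1.761
point = 1.761
stats.norm(red_mean_guess, red_std_guess).pdf(point)   # ~0.189   -> much more likely red
stats.norm(blue_mean_guess, blue_std_guess).pdf(point) # ~0.00003 -> very unlikely blue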

For each data point, we can turn these two likelihood values into weights (step 3) so that they sum to 1 as follows:

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
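
A quick way to convince yourself that these weights behave as intended (each pair lies between 0 and 1 and sums to 1) is an assertion like the following; this check is an addition for illustration, not part of the original code:

assert np.allclose(red_weight + blue_weight, 1.0)
assert ((red_weight >= 0) & (red_weight <= 1)).all()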

With our current estimates and our newly-computed weights, we can now compute new estimates for the mean and standard deviation of the red and blue groups (step 4).

We compute the mean and standard deviation twice, using all data points but with different weightings: once with the red weights and once with the blue weights.

The key bit of intuition is that the greater the weight of a colour on a data point, the more the data point influences the next estimates for that colour's parameters. This has the effect of "pulling" the parameters in the right direction.

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This gives an estimate of the variance; take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

We have new estimates for the parameters. To improve them again, we can jump back to step 2 and repeat the process. We do this until the estimates converge, or after some set number of iterations has been performed (step 5).
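
Putting steps 2 to 4 together, a minimal sketch of the full loop might look like the following (this sketch reuses the guesses and the estimate_mean/estimate_std helpers defined above; the fixed cap of 20 iterations is an assumption for brevity, and a real run would typically also test for convergence):

for _ in range(20):
    # step 2: likelihood of every point under the current guesses
    likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
    likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

    # step 3: turn the likelihoods into weights that sum to 1
    likelihood_total = likelihood_of_red + likelihood_of_blue
    red_weight = likelihood_of_red / likelihood_total
    blue_weight = likelihood_of_blue / likelihood_total

    # step 4: re-estimate the parameters from the weighted data
    red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
    blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
    red_mean_guess = estimate_mean(both_colours, red_weight)
    blue_mean_guess = estimate_mean(both_colours, blue_weight)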

For our data, the first five iterations of this process look like this (more recent iterations are drawn more strongly):

[Image: the guessed bell curves over the first five iterations, more recent iterations drawn more strongly]

We see that the means are already converging on some values, and the shapes of the curves (governed by the standard deviation) are also becoming more stable.

If we continue for 20 iterations, we end up with the following:

[Image: the bell curves after 20 iterations of EM, closely matching the data]

The EM process has converged to the following values, which turn out to be very close to the actual values (computed above where we could see the colours, i.e. with no hidden variables):

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

In the code above you may have noticed that the new estimate for the standard deviation was computed using the previous iteration's estimate for the mean. Ultimately it does not matter whether we compute a new value for the mean first, as we are just finding the (weighted) variance of the values around some central point. We will still see the parameter estimates converge.


EM is an algorithm for maximizing a likelihood function when some of the variables in your model are unobserved (i.e. when you have latent variables).

You might fairly ask: if we're just trying to maximize a function, why don't we just use the existing machinery for maximizing functions? Well, if you try to maximize this by taking derivatives and setting them to zero, you find that in many cases the first-order conditions don't have a solution. There's a chicken-and-egg problem: to solve for your model parameters you need to know the distribution of your unobserved data, but the distribution of your unobserved data is a function of your model parameters.

E-M tries to get around this by iteratively guessing a distribution for the unobserved data, then estimating the model parameters by maximizing something that is a lower bound on the actual likelihood function, and repeating until convergence:

The EM algorithm

Start with guess for values of your model parameters

E-step: For each datapoint that has missing values, use your model equation to solve for the distribution of the missing data given your current guess of the model parameters and given the observed data (note that you are solving for a distribution for each missing value, not for the expected value). Now that we have a distribution for each missing value, we can calculate the expectation of the likelihood function with respect to the unobserved variables. If our guess for the model parameters was correct, this expected likelihood will be the actual likelihood of our observed data; if the parameters were not correct, it will just be a lower bound.

M-step: Now that we've got an expected likelihood function with no unobserved variables in it, maximize the function as you would in the fully observed case, to get a new estimate of your model parameters.

Repeat until convergence.
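
As a rough, self-contained illustration of this E-step/M-step structure, here is a sketch for a two-component one-dimensional Gaussian mixture with known, equal variances; the toy data and all names below are assumptions for the example, not something taken from the original answer:

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0, 1, 100), rng.normal(5, 1, 100)])

mu = np.array([-1.0, 1.0])  # initial guess for the two unknown means

for _ in range(50):
    # E-step: distribution over the unobserved component label for each point
    dens = np.stack([stats.norm(m, 1).pdf(x) for m in mu])
    resp = dens / dens.sum(axis=0)  # responsibilities, one row per component

    # M-step: maximise the expected log-likelihood -> weighted means
    mu = (resp * x).sum(axis=1) / resp.sum(axis=1)

print(mu)  # the estimates should end up near 0 and 5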


Here is a straightforward recipe for understanding the Expectation Maximisation algorithm:

1- Read this EM tutorial paper by Do and Batzoglou.

2- If you still have question marks in your head, have a look at the explanations on this maths stack exchange page.

3- Look at this code that I wrote in Python that explains the example in the EM tutorial paper of item 1:

Warning: the code may be messy/suboptimal, since I am not a Python developer. But it does the job.

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln(pmf) = ln(multinomial coefficient) + weighted sum of log probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = list(zip(head_counts,tail_counts)) # list() so it can be indexed and len()'d under Python 3

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A))
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B))

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1
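
Once the loop exits, j holds the number of iterations performed and pA_heads[j], pB_heads[j] hold the final estimates, so you can inspect the result with something like this (an added convenience, not part of the original code):

print("EM estimates after %d iterations: pA(heads)=%.2f, pB(heads)=%.2f"
      % (j, pA_heads[j], pB_heads[j]))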

Technically the term "EM" is a bit underspecified, but I assume you are referring to the Gaussian Mixture Modelling cluster analysis technique, which is an instance of the general EM principle.

Actually, EM cluster analysis is not a classifier. I know that some people consider clustering to be "unsupervised classification", but actually cluster analysis is something quite different.

The key difference, and the big misunderstanding classification people often bring to cluster analysis, is that in cluster analysis there is no "correct solution". It is a knowledge discovery method; it is actually meant to find something new! This makes evaluation very tricky. It is often evaluated using a known classification as reference, but that is not always appropriate: the classification you have may or may not reflect what is in the data.

Let me give you an example: you have a large data set of customers, including gender data. A method that splits this data set into "male" and "female" is optimal when you compare it with the existing classes. In a "prediction" way of thinking this is good, as for new users you could now predict their gender. In a "knowledge discovery" way of thinking this is actually bad, because you wanted to discover some new structure in the data. A method that instead split the data into, say, elderly people and kids would score as badly as possible with respect to the male/female classes, yet it would be an excellent clustering result (if age wasn't given).

Now back to EM. Essentially it assumes that your data is composed of multiple multivariate normal distributions (note that this is a very strong assumption, in particular when you fix the number of clusters!). It then tries to find a locally optimal model by alternately improving the model and the assignment of objects to the model.
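
For a concrete feel for this kind of mixture model, here is a hedged sketch using scikit-learn's GaussianMixture, which fits exactly such a model with EM; the toy two-dimensional data and the choice of two components are assumptions for the example:

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(42)
X = np.vstack([rng.normal([0, 0], 1, size=(100, 2)),
               rng.normal([4, 4], 1, size=(100, 2))])

gmm = GaussianMixture(n_components=2, random_state=0).fit(X)
hard_labels = gmm.predict(X)        # most likely component per point
soft_labels = gmm.predict_proba(X)  # the "soft" assignments EM alternates over
print(gmm.means_)                   # should land near (0, 0) and (4, 4)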

For best results in a classification context, choose the number of clusters larger than the number of classes, or even apply the clustering to single classes only (to find out whether there is some structure within the class!).

Say you want to train a classifier to tell apart "cars", "bikes" and "trucks". There is little use in assuming the data to consist of exactly 3 normal distributions. However, you may assume that there is more than one type of car (and of truck and bike). So instead of training a classifier for these three classes, you cluster cars, trucks and bikes into 10 clusters each (or maybe 10 cars, 3 trucks and 3 bikes, whatever), then train a classifier to tell apart these 30 classes, and then merge the class results back into the original classes. You may also discover that there is one cluster that is particularly hard to classify, for example trikes. They're somewhat cars and somewhat bikes. Or delivery trucks, which are more like oversized cars than trucks.


The other answers are good, so I will try to provide another perspective and tackle the intuitive part of the question.

The EM (Expectation-Maximization) algorithm is an instance of a class of iterative algorithms that use duality.

Excerpt (emphasis mine):

In mathematics, a duality, generally speaking, translates concepts, theorems or mathematical structures into other concepts, theorems or structures, in a one-to-one fashion, often (but not always) by means of an involution operation: if the dual of A is B, then the dual of B is A. Such involutions sometimes have fixed points, so that the dual of A is A itself

Usually a dual B of an object A is related to A in some way that preserves some symmetry or compatibility. For example, AB = const.

Examples of iterative algorithms, employing duality (in the previous sense) are:

  1. Euclidean algorithm for Greatest Common Divisor, and its variants
  2. Gram–Schmidt Vector Basis algorithm and variants
  3. Arithmetic Mean - Geometric Mean Inequality, and its variants
  4. Expectation-Maximization algorithm and its variants (see also here for an information-geometric view)
  5. (.. other similar algorithms..)

In a similar fashion, the EM algorithm can also be seen as two dual maximization steps:

..[EM] is seen as maximizing a joint function of the parameters and of the distribution over the unobserved variables.. The E-step maximizes this function with respect to the distribution over the unobserved variables; the M-step with respect to the parameters..

In an iterative algorithm using duality there is the explicit (or implicit) assumption of an equilibrium (or fixed) point of convergence (for EM this is proved using Jensen's inequality).

So the outline of such algorithms is:

  1. E-like step: Find the best solution x with y held constant.
  2. M-like step (dual): Find the best solution y with x (as computed in the previous step) held constant.
  3. Criterion of termination/convergence: Repeat steps 1 and 2 with the updated values of x and y until convergence (or until a specified number of iterations is reached).
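
To make this alternating pattern concrete, here is a toy coordinate-ascent sketch (purely illustrative, not EM itself) that maximises f(x, y) = -(x - y)^2 - (y - 3)^2 by solving for one variable while the other is held constant:

def best_x_given_y(y):
    return y              # argmax over x of f(x, y) for fixed y

def best_y_given_x(x):
    return (x + 3) / 2    # argmax over y of f(x, y) for fixed x

x, y = 0.0, 0.0
for _ in range(20):
    x = best_x_given_y(y)  # "E-like" step: optimise x with y held constant
    y = best_y_given_x(x)  # "M-like" step: optimise y with x held constant

print(x, y)  # both converge towards 3, the maximiser of f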

Note that when such an algorithm converges to a (global) optimum, it has found a configuration which is best in both senses (i.e. in both the x domain/parameters and the y domain/parameters). However, the algorithm may only find a local optimum rather than the global optimum.

I would say this is the intuitive description of the outline of the algorithm.

For the statistical arguments and applications, other answers have given good explanations (check also the references in this answer).