In this post, we will mimic a Dirichlet process Gaussian mixture model using Stan. Since Stan doesn’t provide a Dirichlet process prior, we approximate it with a finite (truncated) mixture model. The stick-breaking process is implemented inside the Stan code, and in my personal view this approach is fairly similar to the one in the PyMC3 manual link.

First, I will generate data that is a mixture of 3 normal distributions, and then we will fit the model using Stan. Here we will just check the capability; a detailed analysis of the results is omitted.

Data generation

rm(list = ls())
library(mixtools)
library(ggplot2)
library(tidyverse)
library(magrittr)
# Data generation code retrieved from
# http://www.jarad.me/615/2013/11/13/fitting-a-dirichlet-process-mixture

dat_generator <- function(truth) {
    set.seed(1)
    n = 500
    
    # true mixture density (kept for reference)
    f = function(x) {
        out = numeric(length(x))
        for (i in 1:length(truth$pi)) out = out + truth$pi[i] * dnorm(x, truth$mu[i], 
            truth$sigma[i])
        out
    }
    # n draws from the mixture
    y = rnormmix(n, truth$pi, truth$mu, truth$sigma)
    # n draws from each component separately (y1, y2, y3)
    for (i in 1:length(truth$pi)) {
        assign(paste0("y", i), rnorm(n, truth$mu[i], truth$sigma[i]))
    }
    dat <- data_frame(y = y, y1 = y1, y2 = y2, y3 = y3)
}
truth = data.frame(pi = c(0.1, 0.5, 0.4), mu = c(-3, 0, 3), sigma = sqrt(c(0.5, 
    0.75, 1)))
dat <- dat_generator(truth)

The data (\(\textbf{y}\)) is a mixture of \(\textbf{y}_1\), \(\textbf{y}_2\), and \(\textbf{y}_3\), where \(y_1\sim\text{N}(-3,\,0.5)\), \(y_2\sim\text{N}(0,\,0.75)\), and \(y_3\sim\text{N}(3,\,1)\) (the second argument is the variance), and the mixing proportions are \(\boldsymbol{\pi}=(0.1,\,0.5,\,0.4)\).
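As a quick sanity check, the empirical mean of \(\textbf{y}\) should be close to the theoretical mixture mean \(\sum_c\pi_c\mu_c=0.9\):

# quick sanity check: empirical vs. theoretical mixture mean
sum(truth$pi * truth$mu)  # theoretical mean: 0.1*(-3) + 0.5*0 + 0.4*3 = 0.9
mean(dat$y)               # should be close to 0.9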

ggplot(data = dat %>% gather(key, value), aes(value)) + geom_density(aes(color = key)) + 
    theme_bw() + xlab("y") + ggtitle("y is mixture of {y1,y2,y3}")

Dirichlet process with base distribution \(G_0=\text{N}(\mu_0,\Sigma_0)\)

\[G=\sum_{k=1}^{\infty}\pi_k \delta_{\mu_k},\qquad G\sim\mathcal{DP}(\alpha,\text{N}(\mu_0,\Sigma_0))\] But to mimic the Dirichlet process in Stan, we assume the maximum number of mixture components is fixed at \(C\) (a truncated stick-breaking construction): \[v_c\sim\text{Beta}(1,\alpha) \quad (\text{a hyperprior can be assigned to } \alpha \text{, e.g. } \alpha \sim\text{Gamma}(a,b))\] \[\pi_c=v_c\prod_{i=1}^{c-1}(1-v_i),\quad\text{so that}\ \pi_1=v_1\]
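To see what this truncated stick-breaking does, here is a minimal R sketch that draws \(v\) from \(\text{Beta}(1,\alpha)\) and folds it into the weights \(\pi\); as in the Stan code below, the last weight absorbs the remainder so that the weights sum to one.

# illustration: truncated stick-breaking with C sticks
set.seed(2)
alpha <- 1
C <- 10
v <- rbeta(C, 1, alpha)
pi_sb <- numeric(C)
pi_sb[1] <- v[1]
for (c in 2:(C - 1)) pi_sb[c] <- v[c] * prod(1 - v[1:(c - 1)])
pi_sb[C] <- 1 - sum(pi_sb[1:(C - 1)])  # last weight absorbs the remainder
round(pi_sb, 3)
sum(pi_sb)  # exactly 1 by construction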

And the mean of each cluster is given by \[\mu_c \sim\text{N}(\mu_0,\sigma_0^2)\]

Then the likelihood of the Gaussian mixture model is \[\text{Pr}(y_i|\mu_0,\sigma_0,\alpha)=\sum_{z_i}\text{Pr}(z_i)\text{Pr}(y_i|z_i)=\sum_{c=1}^{C}\pi_c\,\text{N}(y_i|\mu_c,\sigma_c^2)\] Here \(z_i\) is the cluster assignment of \(y_i\), and \(\pi_c=\text{Pr}(z_i=c)\).
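In Stan we will marginalize \(z_i\) out of the likelihood with log_sum_exp; a minimal R sketch of the same computation (evaluated at the true parameters, just for illustration) looks like this.

# log-likelihood of one observation under the mixture, marginalizing out z
log_sum_exp <- function(x) max(x) + log(sum(exp(x - max(x))))
mix_loglik <- function(y, pi, mu, sigma) {
    ps <- log(pi) + dnorm(y, mu, sigma, log = TRUE)
    log_sum_exp(ps)
}
mix_loglik(dat$y[1], truth$pi, truth$mu, truth$sigma)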

library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

stan_model <- "
data{
  int<lower=0> C; // number of clusters (truncation level)
  int<lower=0> N; // number of data points
  real y[N];
}

parameters {
  real mu_cl[C]; //cluster mean
  real <lower=0,upper=1> v[C];
  real<lower=0> sigma_cl[C]; // error scale
  //real<lower=0> alpha; // hyper prior DP(alpha,base)
}

transformed parameters{
  simplex [C] pi;
  pi[1] = v[1];
  // stick-breaking process based on The BUGS book Chapter 11 (p.294)
  // equivalent to pi[j] = v[j] * prod_{i<j} (1 - v[i])
  for(j in 2:(C-1)){
      pi[j] = v[j]*(1-v[j-1])*pi[j-1]/v[j-1]; 
  }
  pi[C]=1-sum(pi[1:(C-1)]); // to make a simplex.
}

model {
  real alpha = 1;
  real a=0.001;
  real b=0.001;
  real ps[C];
  sigma_cl ~ inv_gamma(a,b);
  mu_cl ~ normal(0,5);
  //alpha~gamma(6,1);
  v ~ beta(1,alpha);
  
  for(i in 1:N){
    for(c in 1:C){
      ps[c]=log(pi[c])+normal_lpdf(y[i]|mu_cl[c],sigma_cl[c]);
    }
    target += log_sum_exp(ps);
  }

}
"
y <- dat$y
C <- 10  # truncation level: large enough to cover the true number of clusters
N <- length(y)
input_dat <- list(y = y, N = N, C = C)
# model_object<-stan_model(model_code=stan_model)
fit <- stan(model_code = stan_model, data = input_dat, iter = 1000, chains = 1)
results <- rstan::extract(fit)

Check the results. As you can see, only 3 clusters receive non-negligible weight.

plot_dat_pi <- data.frame(results$pi) %>% as_data_frame() %>% set_names(sprintf("pi%02d", 
    1:10))

ggplot(data = plot_dat_pi %>% gather(key, value), aes(x = key, y = value)) + 
    geom_boxplot() + theme_bw()
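We can also count the “active” clusters numerically; here I use an arbitrary cutoff of 0.05 on the posterior mean weights.

# rough numeric check: how many clusters carry non-negligible weight?
pi_hat <- colMeans(results$pi)
round(pi_hat, 3)
sum(pi_hat > 0.05)  # expect 3 with this data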

Check the posterior of \(\mu_c\); each posterior looks well mixed.

library(gridExtra)
plot_mu_dat <- data.frame(results$mu_cl[, 1:3]) %>% as_data_frame() %>% set_names(sprintf("mu%d", 
    1:3))
plot_mu_dat %<>% mutate(xgrid = (1:length(plot_mu_dat$mu1)))
ggplot(plot_mu_dat %>% gather(key, value, mu1:mu3), aes(x = xgrid, y = value, 
    color = key)) + geom_point() + theme_bw()
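For a more formal convergence check, rstan’s built-in summary and traceplot can be used as well:

# convergence diagnostics for the cluster means
print(fit, pars = "mu_cl")            # Rhat and n_eff per component
rstan::traceplot(fit, pars = "mu_cl")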

Let’s summarize the results against the truth (very roughly).

knitr::kable(truth %>% as_data_frame())
|  pi| mu|     sigma|
|---:|--:|---------:|
| 0.1| -3| 0.7071068|
| 0.5|  0| 0.8660254|
| 0.4|  3| 1.0000000|
mean_results <- data.frame(pi = colMeans(results$pi[, 1:3]), mu = colMeans(results$mu_cl[, 
    1:3]), sigma = colMeans(results$sigma_cl[, 1:3]))
knitr::kable(mean_results)
|        pi|         mu|     sigma|
|---------:|----------:|---------:|
| 0.1172710| -2.8945143| 0.9337762|
| 0.3670522|  2.9015420| 1.1602458|
| 0.5143328|  0.0432029| 0.9082477|

I assigned a rather non-informative prior, but because the problem is simple, the results are in fairly good agreement with the generated data (truth). You can also observe that the order of the clusters is permuted; this is the well-known label-switching issue in mixture modeling. You can mitigate it with several methods. See this. But most of those settings rely on something like a known ordering of the \(\mu\)s.
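For reference, here is a minimal post-hoc relabeling sketch: order the three dominant clusters by \(\mu\) within each posterior draw (this assumes the ordering of the means identifies the clusters).

# post-hoc relabeling: sort the three dominant clusters by mu within each draw
mu_draws <- results$mu_cl[, 1:3]
pi_draws <- results$pi[, 1:3]
ord <- t(apply(mu_draws, 1, order))
mu_sorted <- t(sapply(seq_len(nrow(ord)), function(s) mu_draws[s, ord[s, ]]))
pi_sorted <- t(sapply(seq_len(nrow(ord)), function(s) pi_draws[s, ord[s, ]]))
colMeans(mu_sorted)  # should roughly match mu = (-3, 0, 3)
colMeans(pi_sorted)  # should roughly match pi = (0.1, 0.5, 0.4)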