Appendix A — BART

A.1 Bayesian additive regression trees

The following slides give a brief introduction to BART (Chipman, George, and McCulloch 2012), a cornerstone modelling technique for the remainder of these notes. This is a natural place to discuss BART: we have already covered the basics of Bayesian statistics and machine learning in the notes, and further studies of adaptive basis techniques and Bayesian non-parametric models were given in Chapters 7 and 8, respectively. These frameworks help contextualize BART in the grander scheme of things. stochtree is excellent software for implementing BART and its variants, and we will lean on it heavily throughout the remainder of these notes.

On desktop, hover the mouse over the slides and press f to enter full-screen mode. You can also enter full-screen mode by clicking the menu in the bottom-left corner of the slides.

A.2 The double summing of BART

Note: the code below is NOT RUN. Running it will reproduce slide 54 from the BART slides above.

Across four real datasets, BART delivers consistent performance over train/test splits.
suppressMessages(library(modeldata))
suppressMessages(library(agridat))
suppressMessages(library(caret))
RMSE <- function(pred, y){
  sqrt(mean((pred - y)^2))
}



data("cleveland.soil")

X_cleveland = as.matrix(cbind(as.numeric(cleveland.soil$easting), as.numeric(cleveland.soil$northing)))
colnames(X_cleveland) = c('X1', 'X2')
y_cleveland = cleveland.soil$resistivity


n = nrow(X_cleveland)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- nrow(X_cleveland) - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_cleveland_test <- as.data.frame(X_cleveland[test_inds,])
X_cleveland_train <- as.data.frame(X_cleveland[train_inds,])
y_cleveland_train <- y_cleveland[train_inds]
y_cleveland_test <- y_cleveland[test_inds]

mean_forest_params = list(num_trees=100)
num_gfr <- 20
num_burnin <- 20
num_mcmc <- 400
fitz_cleveland = stochtree::bart(X_train = X_cleveland_train, y_train = y_cleveland_train,
                                 num_gfr = num_gfr,
                                 num_burnin=num_burnin,
                                 num_mcmc = num_mcmc,
                               mean_forest_params = mean_forest_params,
                               general_params = list(num_chains=1, verbose=F))
pred_cleveland_full = predict(fitz_cleveland, X_cleveland)
pred_cleveland_test = predict(fitz_cleveland, X_cleveland_test)
pred_cleveland_train = predict(fitz_cleveland, X_cleveland_train)

data('Chicago')
X = Chicago[, c('temp_min', 'temp', 'temp_max','temp_change',
                'dew', 'humidity','pressure', 'pressure_change',
                'wind', 'wind_max', 'gust', 'gust_max', 'percip',
                'percip_max', 'weather_rain', 'weather_snow', 'weather_cloud',
                'weather_storm', 'Blackhawks_Home',
                'Bulls_Home', 'Bears_Home', 'WhiteSox_Home', 'Cubs_Home', 'date')]

X$day = weekdays(as.Date(X$date))
X$year = format(as.Date(X$date, format="%d/%m/%Y"),"%Y")
X$month = format(as.Date(X$date, format="%d/%m/%Y"),"%m")
X_copy = X
X = X %>%
  dplyr::select(-date)

one_hot = dummyVars(" ~ .", data=X)

X_chicago = data.frame(X)  # predict(one_hot, newdata=X) intentionally not applied; factors are set below
cat_cols_chicago = colnames(X_chicago)[19:ncol(X_chicago)]

for (col in cat_cols_chicago) {
    X_chicago[,col] <- factor(X_chicago[,col], ordered = F)
 }
y_chicago = Chicago$ridership
n = nrow(X_chicago)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- nrow(X_chicago) - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_chicago_test <- as.data.frame(X_chicago[test_inds,])
X_chicago_train <- as.data.frame(X_chicago[train_inds,])
y_chicago_train <- y_chicago[train_inds]
y_chicago_test <- y_chicago[test_inds]
fitz_chicago = stochtree::bart(X_train = X_chicago_train, y_train = y_chicago_train,
                       num_gfr = num_gfr,
                       num_burnin=num_burnin,
                       num_mcmc = num_mcmc,
                       mean_forest_params = mean_forest_params,
                       general_params = list(num_chains=1, verbose=F))
pred_chicago_full = predict(fitz_chicago, X_chicago)
pred_chicago_test = predict(fitz_chicago, X_chicago_test)
pred_chicago_train = predict(fitz_chicago, X_chicago_train)

data('deliveries')
y_delivery = deliveries$time_to_delivery
X_delivery = deliveries %>%
  dplyr::select(-time_to_delivery)

one_hot = dummyVars(" ~ .", data=X_delivery)

X_delivery = data.frame(X_delivery)  # predict(one_hot, newdata=X_delivery) intentionally not applied; factors are set below
cat_cols_delivery = c(colnames(X_delivery)[2:8],colnames(X_delivery)[10:ncol(X_delivery)])
for (col in cat_cols_delivery) {
  X_delivery[,col] <- factor(X_delivery[,col], ordered = F)
}

n = nrow(X_delivery)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- nrow(X_delivery) - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_delivery_test <- as.data.frame(X_delivery[test_inds,])
X_delivery_train <- as.data.frame(X_delivery[train_inds,])
y_delivery_train <- y_delivery[train_inds]
y_delivery_test <- y_delivery[test_inds]



fitz_delivery = stochtree::bart(X_train = X_delivery_train,
                               y_train = y_delivery_train,
                               num_gfr = num_gfr,
                               num_burnin=num_burnin,
                               num_mcmc = num_mcmc,
                               mean_forest_params = mean_forest_params,
                               general_params = list(num_chains=1, verbose=F))
pred_delivery_full = predict(fitz_delivery, X_delivery)
pred_delivery_test = predict(fitz_delivery, X_delivery_test)
pred_delivery_train = predict(fitz_delivery, X_delivery_train)

petsdata = read.csv(paste0(here::here(), '/data/pets_data.csv'))
z = petsdata$genhealth
X_pets = petsdata%>%
  dplyr::select(-c(gotdog))

cat_cols_pets= c(colnames(X_pets)[2:7],colnames(X_pets)[9:ncol(X_pets)])
for (col in cat_cols_pets) {
  X_pets[,col] <- factor(X_pets[,col], ordered = F)
}



y_pets = petsdata$gotdog

# Split data into test and train sets
test_set_pct <- 0.2
n = nrow(X_pets)
n_test <- round(test_set_pct*nrow(X_pets))
n_train <- nrow(X_pets) - n_test
test_inds <- sort(sample(1:nrow(X_pets), n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- as.data.frame(X_pets[test_inds,])
X_train <- as.data.frame(X_pets[train_inds,])

y_test <- y_pets[test_inds]
y_train <- y_pets[train_inds]
bart_params = list(num_trees = 100,
                   alpha = 0.95, beta = 2,
                   min_samples_leaf = 5, sample_sigma2_leaf=T)
# Set number of iterations
num_gfr <- 50
num_burnin <- 0
num_mcmc <- 400
num_samples <- num_gfr + num_burnin + num_mcmc





fitz_pets = stochtree::bart(X_train = X_train,
                        y_train = y_train,
                        num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,

                        general_params = list(probit_outcome_model=T,
                                              sample_sigma2_global=F,
                                              num_chains=1, verbose=F),
                        mean_forest_params = bart_params)
pred_pets_full = predict(fitz_pets, X_pets)
pred_pets_test = predict(fitz_pets, X_test)
pred_pets_train = predict(fitz_pets, X_train)

cor_frame_chicago = sapply(1:400, function(m)RMSE(pred_chicago_full$y_hat[,m], y_chicago))
cor_frame_cleveland = sapply(1:400, function(m)RMSE(pred_cleveland_full$y_hat[,m], y_cleveland))

cor_frame_deliveries = sapply(1:400, function(m)RMSE(pred_delivery_full$y_hat[,m], y_delivery))

cor_frame_pets = sapply(1:400, function(m)pROC::auc(y_pets,pnorm(pred_pets_full$y_hat[,m]),  quiet=T))

cor_frame_chicago_train = sapply(1:400, function(m)RMSE(pred_chicago_train$y_hat[,m], y_chicago_train))
cor_frame_cleveland_train = sapply(1:400, function(m)RMSE(pred_cleveland_train$y_hat[,m], y_cleveland_train))

cor_frame_deliveries_train = sapply(1:400, function(m)RMSE(pred_delivery_train$y_hat[,m], y_delivery_train))

cor_frame_pets_train = sapply(1:400, function(m)pROC::auc(y_train,pnorm(pred_pets_train$y_hat[,m]), quiet=T))



cor_frame_chicago_test = sapply(1:400, function(m)RMSE(pred_chicago_test$y_hat[,m], y_chicago_test))
cor_frame_cleveland_test = sapply(1:400, function(m)RMSE(pred_cleveland_test$y_hat[,m], y_cleveland_test))

cor_frame_deliveries_test = sapply(1:400, function(m)RMSE(pred_delivery_test$y_hat[,m], y_delivery_test))

cor_frame_pets_test = sapply(1:400, function(m)pROC::auc(y_test,pnorm(pred_pets_test$y_hat[,m]), quiet=T))


df_error = data.frame(correlations = c(cor_frame_chicago_train, cor_frame_chicago_test,
                                       cor_frame_cleveland_train, cor_frame_cleveland_test,
                                 cor_frame_deliveries_train, cor_frame_deliveries_test,
                                 cor_frame_pets_train, cor_frame_pets_test),
                group = c(rep(c(rep('Train',400), rep('Test',400)),4)),
                data = c(rep('Chicago',400*2), rep('Cleveland', 400*2),
                         rep('Deliveries',400*2),
                         rep('Pets', 400*2)),
                index = rep(seq(from=1,to=400, length.out=400),8))


plot(y_chicago, rowMeans(pred_chicago_full$y_hat), bty='ln',xlab='Daily ridership', ylab='Predicted ridership',pch=16, col='#39204f')
abline(a=0, b=1, col='#7c1d1d', lwd=4)
plot(y_cleveland, rowMeans(pred_cleveland_full$y_hat),bty='ln', xlab='Soil resistivity', ylab='Predicted soil resistivity',pch=16, col='#39204f')
abline(a=0, b=1, col='#7c1d1d', lwd=4)
plot(y_delivery, rowMeans(pred_delivery_full$y_hat), bty='ln',xlab='Delivery time', ylab='Predicted delivery time',pch=16, col='#39204f')
abline(a=0, b=1, col='#7c1d1d', lwd=4)


plot_5 <- df_error %>%
  dplyr::filter(data %in% 'Chicago')%>%
  ggplot(aes(x=index, y=correlations, color=group))+

  geom_point()+
  scale_color_manual(values=c('#012296', '#55ad89'))+
  annotate(geom='text', x=250, y=2.06,
           label = 'Test set',
          color='#012296',
          size=8,
          fontface='bold',
           hjust=1, vjust=0 )+
 annotate(geom='text', x=150, y=1.85,
           label = 'Train set',
           color='#55ad89',

           size=8,
           fontface='bold',
           hjust=1, vjust=0)+

  ylab("RMSEs")+
  xlab('Posterior draw')+
  # xlim(c(0.5,1.0))+
  theme_minimal(base_size=32)+
  theme(legend.position = "none")

plot_6 <- df_error %>%
  dplyr::filter(data %in% 'Cleveland')%>%
  ggplot(aes(x=index, y=correlations, color=group))+

  geom_point()+

  scale_color_manual(values=c('#012296', '#55ad89'))+
  annotate(geom='text', x=250, y=13,
           label = 'Test set',
           color='#012296',
           size=8,
           fontface='bold',
           hjust=1, vjust=0)+
  annotate(geom='text', x=150, y=10.5,
           label = 'Train set',
           color='#55ad89',
           size=8,
           fontface='bold',
           hjust=1, vjust=0)+
  ylab("RMSEs")+
  xlab('Posterior draw')+
  # xlim(c(0.5,1.0))+
  theme_minimal(base_size=32)+
  theme(legend.position = "none")

plot_7 <- df_error %>%
  dplyr::filter(data %in% 'Deliveries')%>%
  ggplot(aes(x=index, y=correlations, color=group))+
  ylab("RMSEs")+
  xlab('Posterior draw')+
  geom_point()+
  ##geom_hline(yintercept=RMSE(rowMeans(pred_delivery_train$y_hat), y_delivery_train),
  #           lwd=1, lty=2,col='#55ad89')+
  #geom_hline(yintercept=RMSE(rowMeans(pred_delivery_test$y_hat), y_delivery_test),
  #           lwd=1, lty=2,col='#012296')+
  scale_color_manual(values=c('#012296', '#55ad89'))+
  annotate(geom='text', x=250, y=2.11,
           label = 'Test set',
           color='#012296', 
           size=8,
           fontface='bold',
           hjust=1, vjust=0)+
  annotate(geom='text', x=120, y=1.8,
           label = 'Train set',
           color='#55ad89',
           size=8,
           fontface='bold',
           hjust=1, vjust=0)+
  # xlim(c(0.5,1.0))+
  theme_minimal(base_size=32)+
  theme(legend.position = "none")


plot_8 <- df_error %>%
  dplyr::filter(data %in% 'Pets')%>%
  ggplot(aes(x=index, y=correlations, color=group))+
  ylab("AUCs")+
  xlab('Posterior draw')+
  geom_point()+
  scale_color_manual(values=c('#012296', '#55ad89'))+
  # xlim(c(0.5,1.0))+
  ylim(c(0.67, 0.78))+
  annotate(geom='text', x=150, y=0.678,
           label = 'Test set',
           color='#012296',
           size=8,
           fontface='bold',
           hjust=1, vjust=0)+
  annotate(geom='text', x=345, y=0.77,
           label = 'Train set',
           color='#55ad89', 
           size=8,
           fontface='bold',
           hjust=1, vjust=0)+
  theme_minimal(base_size=32)+
  theme(legend.position = "none")


df = data.frame(correlations = c(cor_frame_chicago, cor_frame_cleveland,
                                 cor_frame_deliveries, cor_frame_pets),
                data = c(rep('Chicago',400), rep('Cleveland', 400),
                         rep('Deliveries',400),
                         rep('Pets', 400)))

plot_1 <- df %>%
  dplyr::filter(data %in% 'Chicago')%>%
  ggplot(aes(x=correlations))+
  xlab('Posterior RMSEs')+
 ## geom_dotplot(
  #                  fill='#073d6d', color='#f8f9fa', binwidth=.022, dotsize=1.4)  +

  geom_histogram(fill='#073d6d', color='#f8f9fa',bins=40,aes(y=after_stat(density)))+
  geom_vline(xintercept = RMSE(rowMeans(pred_chicago_full$y_hat), y_chicago), color='#d47c17', lwd=2)+
  #xlim(c(1.94,2.3))+
  annotate(geom='text', x=1.99, y=16,
           label = 'RMSE of posterior mean\nas point estimate.',
           color='#073d6d',
           size=5.5,
           fontface='bold',
           hjust=1, vjust=0)+
  annotate(geom='rect', xmin=1.86, xmax=2.0, ymin=14, ymax=22,
           fill=alpha('#073d6d', 0.05), color='#073d6d', linewidth=1)+
  annotate('segment', x=1.925, xend=RMSE(rowMeans(pred_chicago_full$y_hat), y_chicago)+0.004,
           y=14, yend=5, color='#d47c17',
           arrow=arrow(length=unit(0.025, 'npc'),
                       ends='last', type='closed')
  )+
  theme_minimal(base_size=32)+
  theme(legend.position = "none")

plot_2 <- df %>%
  dplyr::filter(data %in% 'Cleveland')%>%
  ggplot(aes(x=correlations))+
  xlab('Posterior RMSEs')+
  geom_histogram(fill='#073d6d', color='#f8f9fa',bins=40, aes(y=after_stat(density)))+
  geom_vline(xintercept = RMSE(rowMeans(pred_cleveland_full$y_hat), y_cleveland), color='#d47c17', lwd=2)+

  theme_minimal(base_size=32)+
  theme(legend.position = "none")


plot_3 <- df %>%
  dplyr::filter(data %in% 'Deliveries')%>%
  ggplot(aes(x=correlations))+
  xlab('Posterior RMSEs')+
  geom_histogram(fill='#073d6d', color='#f8f9fa',bins=40, aes(y=after_stat(density)))+
  geom_vline(xintercept =  RMSE(rowMeans(pred_delivery_full$y_hat), y_delivery), color='#d47c17', lwd=2)+

  theme_minimal(base_size=32)+
  theme(legend.position = "none")




plot_4 <- df %>%
  dplyr::filter(data %in% 'Pets')%>%
  ggplot(aes(x=correlations))+
  xlab('Posterior AUCs')+
  geom_histogram(fill='#073d6d', color='#f8f9fa',bins=40, aes(y=after_stat(density)))+
  geom_vline(xintercept =  pROC::auc(y_pets, rowMeans(pnorm(pred_pets_full$y_hat)), quiet=T), color='#d47c17', lwd=2)+
  theme_minimal(base_size=32)+
  theme(legend.position = "none")


suppressMessages(library(gtExtras))
suppressMessages(library(reactablefmtr))
suppressMessages(library(purrr))
suppressMessages(library(tibble))
df_links = tibble(
  name = c("Chicago", "Cleveland","Delivery", 'Pets'),
  link = c("https://modeldata.tidymodels.org/reference/Chicago.html",
           "https://cran.r-project.org/web/packages/agridat/agridat.pdf",
           "https://modeldata.tidymodels.org/reference/deliveries.html",
           "https://search.r-project.org/CRAN/refmans/SBdecomp/html/petsdata.html")
)
link = sprintf('<a href = "%s">%s </a>', df_links$link, df_links$name)
link = map(link, gt::html)
tibble(Dataset = link,
       Outcome = c('Train ridership', 'Soil resistivity', 'Food delivery time', 'Dog ownership'),
       `Sample size` = c(length(y_chicago),
                         length(y_cleveland),
                         length(y_delivery),
                         length(y_pets)),
       `# covariates` = c(ncol(X_chicago),
                       ncol(X_cleveland),
                       ncol(X_delivery),
                       ncol(X_pets)),

      Error =  I(list(plot_5,plot_6, plot_7, plot_8)),
     `Point estimate`= I(list(plot_1,plot_2, plot_3, plot_4)))%>%
  gt() %>%
  tab_options(
    table.background.color = "#f8f9fa" # A light gray, for example
  )%>%
  text_transform(
    locations = cells_body(columns = Error),
    fn = function(x) {
      lapply(x, function(plot_object) {
        ggplot_image(plot_object, height = px(205), aspect_ratio = 1.4) # Adjust height and aspect ratio as needed
      })
    }
  )%>%

text_transform(
  locations = cells_body(columns = `Point estimate`),
  fn = function(x) {
    lapply(x, function(plot_object) {
      ggplot_image(plot_object, height = px(205), aspect_ratio = 1.4) # Adjust height and aspect ratio as needed
    })
  }
)

On the right, each histogram shows the evaluation metric for every posterior draw. The orange line marks the evaluation metric obtained by using the average across posterior draws as the BART point estimate.

A.3 Random effect BART

stochtree allows users to embed random effects alongside a BART term in a larger model. We illustrate this with data from the agridat R package, which describe yields of barley at different Minnesota farming locations over different years as a function of that year's weather conditions. The data are originally from (Immer and Henderson 1943) and were put together in (Wright 2013).

The data are natural for a random effect term because each farm location is observed repeatedly over the years, giving the observations a natural grouping.

With the data prepared, we now shift focus to modeling the data. The idea is to utilize Bayesian Additive Regression Trees (BART) (Chipman, George, and McCulloch 2012) with the stochtree software suite (Herren et al. 2025). Specifically, we use the stochtree implementation of XBART (He and Hahn 2023) to initialize promising BART forests.

As we have seen, BART is a carefully constructed model that is well equipped for prediction tasks. At its core, it represents an outcome \(y\mid \mathbf{x}\) through a sum of regression trees (a “forest”). More formally:

\[ y = \underbrace{E(y\mid \mathbf{x})}_{\text{signal}}+\underbrace{\varepsilon}_{\text{noise}}=\underbrace{f(\mathbf{x})}_{\text{sum of trees}}+\varepsilon, \quad\varepsilon\sim N(0, \sigma^2) \tag{A.1}\]
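Equation A.1 can be illustrated with a toy simulation. This is a sketch only, not stochtree: the two "trees" below are made-up one-split stumps with arbitrary split points and leaf values.

```r
# Toy sketch of Equation A.1: f(x) as a sum of two one-split trees
# ("stumps") plus Gaussian noise. Split points and leaf values are
# invented for illustration.
set.seed(1)
tree1 <- function(x) ifelse(x < 0.5, -1.0, 1.0)    # split at x = 0.5
tree2 <- function(x) ifelse(x < 0.2,  0.5, -0.5)   # split at x = 0.2
x <- runif(200)
f <- tree1(x) + tree2(x)        # the "sum of trees"
y <- f + rnorm(200, sd = 0.1)   # y = f(x) + eps
```

Even this tiny forest produces a piecewise-constant \(f(\mathbf{x})\) with three levels; BART's actual forests sum many such trees, each fit probabilistically.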

Each tree fits what the previous trees miss, as in boosting. The trees are built probabilistically in a Bayesian way. There is a prior probability for growing a tree, which consists of choosing one of the \(\mathbf{x}_m, m=1, \ldots, p\) variables to split on uniformly at random, selecting a random position within the chosen \(\mathbf{x}_{m}\) that determines where the data are split by the tree1, and a prior distribution that assigns the probability of stopping growth at a tree of a given size. Once a tree is built, passing the outcomes \(y\) into buckets based on how the inputs \(\mathbf{x}\) were split up the tree, prior distributions are placed on the value drawn within each tree bucket. Finally, a prior distribution is placed on the error variance \(\sigma^2\). All the priors are constructed so that tree structures are relatively simple, the values drawn in the trees are relatively small, and the error term is relatively conservative. The idea is that, prior to seeing data, we assume \(f(\mathbf{x})\) can be constructed by summing together relatively simple trees that each contribute minimally to the overall fit of \(y\), so as to avoid over-fitting. But, if the data allow, the priors can be over-ridden and a more complicated \(f(\mathbf{x})\) can be learned.

1 so the data fall into left and right buckets depending on which side of the split rule they land on.
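The "stop growing" part of the tree prior can be sketched numerically. In the standard BART tree prior, the probability that a node at depth \(d\) splits is \(\alpha(1+d)^{-\beta}\); the values \(\alpha = 0.95\), \(\beta = 2\) below match those used in the code in these notes.

```r
# Prior probability that a node at depth d is non-terminal (splits):
# alpha * (1 + d)^(-beta). Deep splits become rapidly improbable.
split_prob <- function(d, alpha = 0.95, beta = 2) alpha * (1 + d)^(-beta)
round(split_prob(0:4), 3)
# depth 0 almost always splits; by depth 4 the prior strongly favors stopping
```

This is how the prior keeps individual trees shallow while still letting the data override it when a deeper split pays off.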

The tl;dr is that BART can learn complicated relationships between inputs and how they impact the outcome/response \(y\). At the same time, BART predictions generalize well into the future and are reliable.


With stochtree, we modify Equation A.1 as follows:

\[ y = f(\mathbf{x})+\gamma_j+\varepsilon \longrightarrow y\sim N\left(f(\mathbf{x})+\gamma_j, \sigma^2\right) \]

The modification gives the BART sum-of-trees model a random intercept \(\gamma_j\) for each farm location \(j=1, \ldots, K\), implemented in stochtree following (Gelman et al. 2008). The random effect by farm location is included because the replicates at each farm over the years form a natural grouping that can stand in for intrinsic features of the farm. On its own, an input feature that captures difficult-to-measure intrinsic farm characteristics is valuable, and is a benefit of replicates, so simply including “site” as a column would likely help the model, at least in a predictive sense. Instead, we embed this information in a random intercept term for “site”. The random effects framework is preferable to a plain column indicator for several reasons. The estimation process provides some regularization by pulling the group (“site”) means towards one another. The random intercept term is sampled, lending some interpretability to the variability of yield by location. Finally, observations within a location now have errors that are correlated over time, a consequence of the random intercept being grouped by farm.
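The pulling of group means towards one another can be sketched with the textbook partial-pooling estimate. This is a generic illustration of the idea, not stochtree's sampler: with within-group variance \(\sigma^2\) and between-group variance \(\tau^2\) (values below are made up), the pooled estimate for a group weights its raw mean by its precision.

```r
# Partial pooling sketch: each group mean ybar_j is shrunk toward the
# grand mean, with less shrinkage for larger groups. Generic formula,
# not stochtree's implementation.
pool <- function(ybar_j, n_j, grand_mean, sigma2, tau2) {
  w <- (n_j / sigma2) / (n_j / sigma2 + 1 / tau2)  # weight on the group mean
  w * ybar_j + (1 - w) * grand_mean
}
# two sites with the same raw mean of 30: the smaller site is pulled
# further toward the grand mean of 40
pool(c(30, 30), n_j = c(2, 50), grand_mean = 40, sigma2 = 25, tau2 = 4)
```

The same regularization logic is why a sampled random intercept is gentler than a hard per-site indicator column.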

The final modification is to include an estimate \(\hat{y}\) from a linear regression as a column of the \(\mathbf{x}\) covariate matrix. The thinking is that if there is a linear effect, trees can struggle to learn it without many splits. This is a problem because BART is designed to avoid deep splits when tree building, so allocating many splits to just the linear aspect of \(f(\mathbf{x})\) is wasteful. If there is no linear effect, then there is little harm in including one extra column in the regression. We also give this extra feature extra weight when splitting, coded as a user-adjustable variable in this project. If we do not give extra probability to splitting on this feature and there are many features to split on, then including it is a waste. If we put too much splitting probability on just this column, then we can miss interactions and other non-linear aspects of \(f(\mathbf{x})\). In this example, since there are not many other features, we do not adjust the splitting probability.
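The weight vector that the full code below passes as variable_weights can be sketched as follows (assuming stochtree normalizes the weights internally; the value 1/p mirrors the lm_weight = 1/ncol(X_train) used there). With p columns including the lm_term, the lm_term column gets total weight 1/p and the remaining p − 1 columns share another 1/p evenly.

```r
# Sketch of the variable_weights vector built in the code below:
# lm_term gets total weight 1/p; the other p - 1 columns split 1/p evenly.
p <- 8                      # e.g. 7 weather/date covariates + lm_term
lm_weight <- 1 / p
variable_weights <- c(rep(lm_weight / (p - 1), p - 1), lm_weight)
# after normalization, lm_term holds half of all split probability
variable_weights[p] / sum(variable_weights)
```

So as constructed, the lm_term column carries as much total split probability as all other columns combined, which is the sense in which it is "up-weighted".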

suppressMessages(library(agridat))
suppressMessages(library(dplyr))
suppressMessages(library(stochtree))
suppressMessages(library(tidyverse))
suppressMessages(library(here))
suppressMessages(library(gt))
suppressMessages(library(gtExtras))
yield_df = minnesota.barley.yield[minnesota.barley.yield$year>1926,]
weather_df = minnesota.barley.weather
comb_frame = dplyr::right_join(weather_df,
                              yield_df, by=c('site','year'),
                              relationship = 'many-to-many')


comb_frame = comb_frame %>% drop_na()
comb_frame$group_id = paste0(comb_frame$site, '_', comb_frame$gen_name)#comb_frame$year)
group_ids = as.integer(as.factor(comb_frame$group_id))
group_real = comb_frame$group_id

y = comb_frame$yield
comb_frame$mo = as.character(comb_frame$mo)
X2 = data.frame(comb_frame[,c('year', 'mo', 'cdd', 'hdd', 'precip', 'min', 'max')])
X =  data.frame(year = X2$year,
                mo = X2$mo,
                cdd = X2$cdd,
                hdd = X2$hdd,
                precip = X2$precip,
                min = X2$min,
                max = X2$max,
             #   gen = X2$gen,
                stringsAsFactors = T)

# Split data into test and train sets
# From https://andrewherren.quarto.pub/stochtree-user-guide/bart-intro.html
n_test <- round(0.20*nrow(X))
n_train <- nrow(X) - n_test
test_inds <- sort(sample(1:nrow(X), n_test, replace = FALSE))
X_train = X[-test_inds,]
y_train = y[-test_inds]
X_test = X[test_inds, ]
y_test = y[test_inds]
group_ids_train = group_ids[-test_inds]
group_ids_test = group_ids[test_inds]
group_real_test = group_real[test_inds]
group_real_train = group_real[-test_inds]
#group_ids_train = group_ids
#group_ids_test = group_ids
#group_real_train = group_real
#group_real_test = group_real

#X_train = X
#X_test = X
#y_train = y
#y_test = y


df_fit = data.frame(X_train,y_train)

OLS_fit = lm(y_train~., data=df_fit)

OLS_train = predict(OLS_fit, data.frame(X_train))
OLS_test = predict(OLS_fit, data.frame(X_test))


#OLS_train = predict(OLS_fit)

X_train = data.frame(X_train, lm_term = OLS_train)
X_test = data.frame(X_test, lm_term = OLS_test)

# A log transform of the outcome may be worth exploring.
lm_weight = 1/ncol(X_train)

bart_fit = stochtree::bart(X_train = X_train,
                           y_train = y_train,
                           X_test = X_test,
                           rfx_group_ids_train = group_ids_train,
                           rfx_group_ids_test = group_ids_test,
                           num_gfr = 100,
                           num_burnin= 25,
                           num_mcmc = 100,
                           random_effects_params = list(model_spec= 'intercept_only'),
                           general_params = list(verbose=F, num_chains=10,
                                                 variable_weights=c(rep(lm_weight/(ncol(X_train)-1),
                                                                        ncol(X_train)-1),
                                                                    lm_weight),
                                                 num_threads=10,
                                                 sample_sigma2_global=F),
                           mean_forest_params = list(num_trees = 150,
                                                     alpha=0.95, beta=2,
                                                     min_samples_leaf = 5,
                                                  #   sigma2_leaf_init = 1e-1,
                                                  #   sigma2_leaf_shape = 5, sigma2_leaf_scale=2,
                                                     max_depth=-1
                           ),
                           variance_forest_params = list(    num_trees = 75,
                                                             alpha = 0.25,
                                                             beta = 3.0,
                                                             min_samples_leaf = 10))
'%!in%' <- function(x,y)!('%in%'(x,y))
RMSE <- function(m, o){   sqrt(mean((m - o)^2)) }

# predict over the held-out test set
pred_frame = X_test



y_hat_preds <- predict(bart_fit, X = pred_frame,
                       rfx_group_ids = group_ids_test,
                       terms = "all", type = "mean")
y_hat_post <- predict(bart_fit, X = pred_frame,
                      rfx_group_ids = group_ids_test,
                      terms = "all")

# Posterior per farm/year
yhats <- y_hat_post$y_hat +
  sqrt(y_hat_post$variance_forest_predictions) *
  rnorm(nrow(X_test) * bart_fit$model_params$num_samples)

qm = apply(yhats,
           1,
           quantile,probs=c(.025,.975))


X_test$group_ids =  group_real_test

fits = data.frame(group_ids = X_test[,'group_ids'],
                  posterior = yhats ) %>%
  group_by(group_ids) %>%
  summarize(across(everything(), mean))
# for the mean terms
fits_mean = data.frame(group_ids = X_test[,'group_ids'],
                       posterior = y_hat_post$y_hat ) %>%
  group_by(group_ids) %>%
  summarize(across(everything(), mean))
# repeat for sigma and random effect
fits_sigma = data.frame(group_ids = X_test[,'group_ids'],
                        posterior = sqrt(y_hat_post$variance_forest_predictions) ) %>%
  group_by(group_ids) %>%
  summarize(across(everything(), mean))
fits_rfx = data.frame(group_ids = X_test[,'group_ids'],
                      posterior = y_hat_post$rfx_predictions ) %>%
  group_by(group_ids) %>%
  summarize(across(everything(), mean))

posterior_LB = unlist(lapply(1:nrow(fits), function(q)
  quantile(unlist(fits[q,2:(ncol(fits)-1)]), probs=c(0.025),na.rm=T)))
posterior_UB = unlist(lapply(1:nrow(fits), function(q)
  quantile(unlist(fits[q,2:(ncol(fits)-1)]), probs=c(0.975),na.rm=T)))
# for each row, take the mean
posterior_mean = unlist(lapply(1:nrow(fits_mean), function(q)
  mean(unlist(fits_mean[q,2:(ncol(fits_mean)-1)]), na.rm=T)))
# repeat for sigma
posterior_sigma_mean = unlist(lapply(1:nrow(fits_sigma), function(q)
  mean(unlist(fits_sigma[q,2:(ncol(fits_sigma)-1)]), na.rm=T)))
# Now do quantiles
posterior_sigma_LB = unlist(lapply(1:nrow(fits_sigma), function(q)
  quantile(unlist(fits_sigma[q,2:(ncol(fits_sigma)-1)]), probs=c(0.025),na.rm=T)))

posterior_sigma_UB = unlist(lapply(1:nrow(fits_sigma), function(q)
  quantile(unlist(fits_sigma[q,2:(ncol(fits_sigma)-1)]), probs=c(0.975),na.rm=T)))

# and the random effect
posterior_rfx_mean = unlist(lapply(1:nrow(fits_sigma), function(q)
  mean(unlist(fits_rfx[q,2:(ncol(fits_rfx)-1)]), na.rm=T)))

posterior_rfx_LB = unlist(lapply(1:nrow(fits_sigma), function(q)
  quantile(unlist(fits_rfx[q,2:(ncol(fits_rfx)-1)]), probs=c(0.025),na.rm=T)))

posterior_rfx_UB = unlist(lapply(1:nrow(fits_sigma), function(q)
  quantile(unlist(fits_rfx[q,2:(ncol(fits_rfx)-1)]), probs=c(0.975),na.rm=T)))


plot1 = data.frame(true = y_test, predicted = y_hat_preds$y_hat)%>%
  ggplot(aes(x=true, y = predicted))+geom_point(col='#073d6d', size=0.5)+
  geom_abline(intercept = 0, slope = 1, col='#55ad89', lwd=1)+
  ggtitle('Predicted barley yield vs true barley yield')+
  xlim(0,80)+
  ylim(0,80)+
  theme_minimal(base_size=14)


avg_yield = comb_frame %>%
  group_by(group_id)%>%
  summarize(mean_yield = mean(yield))

plot2 = data.frame(true = avg_yield$mean_yield,
           predicted = posterior_mean,
           LB = posterior_LB,
           UB = posterior_UB)%>%
  ggplot(aes(x=true, y = predicted))+
  #xlim(0,80)+ylim(0,80)+
  geom_abline(intercept = 0, slope = 1, col='#55ad89', lwd=1)+

  geom_errorbar(aes(ymin=LB, ymax=UB), alpha=0.63, col='#d47ca7', width=1.63, linewidth=1.25)+
  geom_point(col='#073d6d')+
  ggtitle('Predicted avg yield for site-gen vs true avg yield')+
  theme_minimal(base_size=14)
gridExtra::grid.arrange(plot1, plot2, nrow=1)

Model fit
df_viz = data.frame(site_gen = avg_yield$group_id,
                    true = avg_yield$mean_yield,
                    predicted = round(posterior_mean,3),
                    LI = posterior_LB,
                    UI = posterior_UB,
                    LI_rfx = posterior_rfx_LB,
                    UI_rfx = posterior_rfx_UB,
                    LI_sigma = posterior_sigma_LB,
                    UI_sigma = posterior_sigma_UB,
                    posterior_interval = round(posterior_mean,3),
                    posterior_rfx = posterior_rfx_mean,
                    posterior_sigma = posterior_sigma_mean)
df_viz$full_dist = lapply(1:nrow(fits), function(q)
  fits[q,2:(ncol(fits)-1)])






df_viz[sample(50,8),] %>%
  gt() %>%
  gt_plt_dist(full_dist, type='density', fill_color='#004D40',
              line_color='#f8f9fa', fig_dim=c(10,80))%>%
  fmt_number(decimals=1)%>%
  cols_width(
    site_gen ~ px(250),
    true ~ px(100),
    predicted ~ px(100),
    full_dist ~ px(400),

  ) %>%
  data_color(
    columns = true,
    palette = "PuOr"
  )  %>%
  data_color(
    columns = predicted,
    palette = "PuOr"
  ) %>%
  cols_hide(columns = c(LI, UI, posterior_interval, posterior_rfx, posterior_sigma)) %>%
  cols_hide(columns = c(LI_rfx, UI_rfx))%>%
  cols_hide(columns = c(LI_sigma, UI_sigma))%>%
  tab_style(
    locations = cells_body(columns = site_gen),
    style = cell_text(weight = "bold")
  ) %>%
  tab_header(
    title = md("Posterior summaries for select site-gen pairs."),
    subtitle = "Train on 80% of site-gen combos, test on 20%, averaging results per site-gen."
  ) %>%
# opt_interactive()%>%
  opt_align_table_header(align = "left") %>%
  tab_options(heading.padding = px(1),
              source_notes.background.color = '#f8f9fa')%>%
  tab_footnote(
    footnote = "Intervals from 2.5 and 97.5% quantiles of full BART posterior, f(x) + random intercept & error uncertainty",
    locations = cells_column_labels(columns = full_dist)
  ) %>%
  tab_footnote(
    footnote = "Actual yield.",
    locations = cells_column_labels(columns = true)
  ) %>%
  opt_footnote_marks(marks = c("standard")) %>%
  opt_footnote_spec(spec_ref = "i", spec_ftr = "i")
Posterior summaries for select site-gen pairs.
Train on 80% of site-gen combos, test on 20%, averaging results per site-gen.

site_gen                      true*  predicted  full_dist†
Crookston_Trebi               39.8   37.3       (density plot)
Duluth_Trebi                  39.6   41.6       (density plot)
Crookston_WisconsinBarbless   39.3   36.6       (density plot)
GrandRapids_Peatland          29.7   30.3       (density plot)
Crookston_Colsess             31.0   33.9       (density plot)
GrandRapids_Odessa            24.7   25.8       (density plot)
Duluth_No474                  33.1   34.3       (density plot)
GrandRapids_No474             26.5   26.1       (density plot)
* Actual yield.
† Intervals from the 2.5% and 97.5% quantiles of the full BART posterior: f(x) plus random-intercept and error uncertainty.


The following code replicates Figure A.1. We do not run it here, because getting the aesthetics right requires manual tweaks that depend on the specific train/test split.

Look at full predictive posterior distributions for different number of replicates. Not run.
#### Final plot ####

counts = X_test %>%
  group_by(group_ids) %>%
  summarize(count = n())
bottom_5 =  X_test %>%
  group_by(group_ids) %>%
  summarize(count = n())%>%
  arrange(count)%>%
  slice(1:5)

top_5 = X_test %>%
  group_by(group_ids) %>%
  summarize(count = n())%>%
  arrange(-count)%>%
  slice(1:5)

selected_group_ids = which(X_test$group_ids %in% c(bottom_5$group_ids,top_5$group_ids))
group_ids_select = unlist(lapply(1:nrow(X_test[selected_group_ids,]),
                        function(p) rep(X_test[selected_group_ids[p],'group_ids'],
                                        ncol(y_hat_post$y_hat))))

fits = data.frame(group_ids = X_test[selected_group_ids,'group_ids'],
                  posterior = yhats[selected_group_ids,]) %>%
  group_by(group_ids) %>%
  summarize(across(everything(), mean))

posterior_long = unlist(lapply(1:10, function(q)fits[q,2:(ncol(fits))]))

group_ids_plot = unlist(lapply(1:10,
                       function(k)rep(fits$group_ids[k],
                                      ncol(y_hat_post$y_hat))))


data.frame(group_ids = factor(group_ids_plot,levels = c(bottom_5$group_ids,
                                                        top_5$group_ids)),
           posterior = posterior_long) %>%
  #group_by(group_ids)%>%
  #summarize(across('posterior', mean))
  ggplot(aes(x = group_ids, y=posterior))+
  xlab('Site-cereal type')+
  ylab("Posterior predictive of average yield by site-cereal")+
  #ylim(70,110)+
  ggdist::stat_gradientinterval(fill='#012296',size=2.75, scale = 0.63,color='#95187D',
                                point_interval='median_qi',
                                alpha=0.72,fatten_point=4,fill_type='segments')+
  ggdist::scale_slab_alpha_continuous(range = c(0, 1))+
  geom_vline(aes(xintercept=5.5), lwd=2, col='#55ad89',lty=2)+
  annotate(geom='text', x=3.25, y=56,
           label = 'Bottom 5 # of samples',
           color='#073d6d',
           size=4,
           fontface='bold',
           hjust=1, vjust=0, lineheight=4)+
  annotate(geom='text', x=8, y=56,
           label = 'Top 5 # of samples',
           color='#073d6d',
           size=4,
           fontface='bold',
           hjust=1, vjust=0, lineheight=4)+
  annotate(geom='rect', xmin=1.5, xmax=3.5, ymin=52, ymax=60,
           fill=alpha('#073d6d', 0.04), color='#073d6d', linewidth=1)+
  annotate(geom='rect', xmin=6.5, xmax=8.5, ymin=52, ymax=60,
           fill=alpha('#073d6d', 0.04), color='#073d6d', linewidth=1)+
  theme_minimal(base_family = "Roboto Condensed", base_size = 16)+
  theme(#axis.text.y = element_blank(),
        #axis.ticks = element_blank(),
        axis.text.x = element_text(angle = 45, hjust = 0.63),
      #  strip.text.y = element_blank(),
        plot.background = element_rect(fill='#f8f9fa', #e6f1f7',
                                       color=NA))#"#f3f3f3",color=NA ))
Figure A.1: The posterior distributions get narrower as the number of replicates increases, as expected. Uncertainty isn’t merely due to limited data (see aleatoric vs epistemic uncertainty).
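The intuition behind the caption can be sketched in a few lines of R (a standalone illustration under a simple normal model with known noise sd, not the BART fit above): the predictive variance for a new observation in a group with \(n\) replicates decomposes into an epistemic piece \(\sigma^2/n\), which shrinks with \(n\), and an aleatoric piece \(\sigma^2\), which does not.

```r
# Standalone sketch: predictive sd for a new observation in a group
# with n replicates, under a normal model with known noise sd sigma.
sigma <- 5                                  # aleatoric (noise) sd, illustrative
n_reps <- c(1, 5, 25, 100)
predictive_sd <- sqrt(sigma^2 / n_reps + sigma^2)  # epistemic + aleatoric
round(predictive_sd, 2)
# 7.07 5.48 5.10 5.02 -- shrinks toward sigma = 5 (the aleatoric floor), never to 0.
```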

A.4 Constraints

A.4.1 Great lakes and heteroskedastic noise

The data come from the National Oceanic and Atmospheric Administration (NOAA) and describe the surface temperature of Lake Superior year over year. Specifically, the following link will bring you directly to the .txt file on NOAA's site. NOAA also hosts data for the other Great Lakes, and historical data for any of them are easy to download. We work in Fahrenheit because the granularity is nice.

The first modification to the base BART model is to take the log of the outcome. This ensures positive predicted surface temperatures and imposes a log-normal distribution on surface temperatures, which exploratory data visualizations suggest is more realistic than a normal distribution. Incorporating hard constraints into BART directly is tricky: we would need to grow the trees (and hence the forest) with the constraints in mind, accounting for the constraint when trees are accepted or rejected, which means modifying the marginal likelihood calculations. Rejecting forests that fail to meet the criterion post hoc violates the dependence structure of the MCMC, meaning you are no longer sampling from the stationary distribution (i.e., the posterior distribution).

So, to ensure the outcome is positive (including the posterior predictive intervals), we actually model \(\log(y) = f(\mathbf{x}) + \varepsilon\), which means that \(y\) is log-normally distributed: \(y \sim \exp\left[N\left(f(\mathbf{x}), \sigma^2(\mathbf{x})\right)\right]\). The error term now acts multiplicatively and is interpreted roughly as a percentage of \(\exp[f(\mathbf{x})]\). This is because \(y = \exp(f(\mathbf{x}) + \varepsilon) = \exp(f(\mathbf{x}))\exp(\varepsilon)\), so the mean forest, the random intercept, and the error term multiply one another were we to analyze each term on its own. A value of 1.01 on this scale means we multiply the other terms in the model by 1.01, i.e., a 1% increase.
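The multiplicative reading can be checked numerically (a standalone sketch with made-up values for \(f(\mathbf{x})\) and \(\varepsilon\), not fitted quantities):

```r
# exp(f + eps) equals exp(f) * exp(eps): on the original scale the error
# acts multiplicatively. Values here are illustrative only.
f_x <- 2.5
eps <- 0.01
all.equal(exp(f_x + eps), exp(f_x) * exp(eps))  # TRUE
exp(eps)  # ~1.01: an error of 0.01 on the log scale is about a +1% effect
```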

You might wonder why we do not use a log-linear BART prior to model this outcome. That is certainly possible, but a little overkill. Heteroskedastic BART in stochtree uses the log-linear BART prior from Murray (2021) because only \(\sigma(\mathbf{x})\) has to be positive; \(f(\mathbf{x})\) and \(y\) do not. In our case, only \(y\) has to be positive, which makes life easier.
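The distinction can be made concrete in a few lines (a standalone sketch with made-up values, not the stochtree internals): the log-linear prior exponentiates only the variance term, while our log-outcome approach exponentiates the entire prediction.

```r
# Illustrative values, not fitted quantities.
f_x <- -1.2            # mean forest output: free to be negative
s_x <- -0.5            # log-variance forest output: free to be negative
sigma2_x <- exp(s_x)   # log-linear prior: guarantees sigma^2(x) > 0
y_draw <- exp(f_x + rnorm(1, 0, sqrt(sigma2_x)))  # log-outcome model: y > 0
c(sigma2_x > 0, y_draw > 0)  # TRUE TRUE
```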

Speaking of heteroskedastic forests, we also include one to learn \(\sigma(\mathbf{x})\). The heteroskedastic forest allows the variance of the error term to be learned from the data. Looking at the temperature data over the course of the year, there is far more variability around the average in the summer than in the other seasons.

Code to run a heteroskedastic BART regression.
options(warn=-1)
suppressMessages(library(stochtree))
suppressMessages(library(tidyr))
suppressMessages(library(tidyverse))
# Source:
# https://coastwatch.glerl.noaa.gov/statistics/average-surface-water-temperature-glsea/

df_init = read.csv(paste0(here::here(),'/data/Superior_all_year_glsea_avg_s_F.csv'))
df = read.csv(paste0(here::here(),'/data/Superior_all_year_glsea_avg_s_F.csv'))

df$day = df$X
df = df %>%
  dplyr::select(-X)
df = df %>%
  pivot_longer(
    cols = starts_with("X"), # the year columns, whose names start with "X"
    names_to = "Year",       # new column holding the original column names
    values_to = "Temp"       # new column holding the temperature values
  )


num_gfr <- 25
num_burnin <- 10
num_mcmc <- 200
num_samples <- num_gfr + num_burnin + num_mcmc
general_params <- list(sample_sigma2_global = F, verbose=F, num_chains=20)
mean_forest_params <- list(num_trees = 25,
                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
variance_forest_params <- list(num_trees = 5, alpha = 0.95,
                               beta = 3, min_samples_leaf = 30)
df = df %>%
  drop_na()

fit = stochtree::bart(X_train = as.matrix(as.numeric(df$day)),
                      y_train = log(as.numeric(df$Temp)-32),
                      num_gfr = num_gfr,
                      num_burnin = num_burnin,
                      num_mcmc = num_mcmc,
                      mean_forest_params = mean_forest_params,
                      general_params = general_params,
                      variance_forest_params = variance_forest_params
)

# Posterior-predictive draws on the original scale: add heteroskedastic
# noise to f(x) on the log scale, then exponentiate.
yhats2 <- exp(fit$y_hat_train[, ] +
  sqrt(fit$sigma2_x_hat_train[,])*rnorm(length(df$day)*(fit$model_params$num_samples)))
# Equivalent draws via rlnorm (log-normal with meanlog f(x), sdlog sigma(x)).
yhats <- matrix(rlnorm(length(df$day)*(fit$model_params$num_samples),
                       fit$y_hat_train, sqrt(fit$sigma2_x_hat_train)),
                ncol = ncol(fit$y_hat_train), nrow = nrow(fit$y_hat_train))
qm = apply(yhats2,
           1,
           quantile,probs=c(.05,.95))


df2 = data.frame(X = as.numeric(df$day),
                 y = as.numeric(df$Temp),
                 bart = exp(rowMeans(fit$y_hat_train)),
                 LI = qm[1,],
                 UI = qm[2,])
df2$X = as.Date(df2$X, format = "%j", origin = "1.1.2014")
index = unlist(lapply(1:length(unique(df2$X)),function(k)
  seq(from=1,
      to=sum(df2$X==unique(df2$X)[k]))))
df2$index = index

plot2 <-
  ggplot(df2, aes(x=X, y=y, group=index)) +
  geom_line(color='#073d6d', size=0.67, alpha=0.22)+
  scale_x_date(date_breaks = "1 month", date_labels =  "%b") +
  geom_line(aes(x=X,
                y=bart+32), color='#55AD89', lwd= 2, lty=1)+
  geom_line(aes(x=X, y=LI+32), color='#d47c17', lwd=0.72,
            alpha=0.95, lty=1)+
  geom_line(aes(x=X, y=UI+32), color='#d47c17', lwd=0.72,
            alpha=0.95,lty=1)+
  geom_hline(aes(yintercept=32), col='#d47ca7',  lwd=1.75)+
  ylab('Temperature (Fahrenheit)')+
  xlab('Time of year')+
  ggtitle('Lake Superior temperature with 90% posterior intervals')+
  theme_minimal(base_family = "Roboto Condensed", base_size = 16)+
  theme(
    #axis.text.y = element_blank(),
    #axis.ticks = element_blank(),
    axis.text.x = element_text(angle = 57, hjust = 0.63),
    # strip.text.y = element_blank(),
    plot.background = element_rect(fill='#ffffff', #e6f1f7',
                                   color=NA),
    plot.title=element_text(hjust=0.5, size=16))
plot2

Without the log transformation, predicted temperatures could fall below 32 °F, which violates the physical constraint that fresh water cannot be liquid below 32 °F. The lake's minimum surface temperature is therefore 32 degrees, reached when it is entirely ice.
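A quick sanity check of the floor (standalone, using simulated log-scale draws rather than the fitted posterior): exponentiating any real-valued draw and adding 32 back can never dip below 32.

```r
set.seed(1)
# Simulated log-scale posterior-predictive draws (illustrative values).
log_draws <- rnorm(1e5, mean = 1.5, sd = 0.8)
temps <- exp(log_draws) + 32   # back-transform to Fahrenheit
min(temps) > 32                # TRUE: the 32-degree floor holds by construction
```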

A.4.2 Hacky prior predictive

In the previous section, we used the log transform to incorporate the constraint of positive outcomes. Another (again hacky) way to do this is to use the kernel representation of BART and incorporate the constraint through rejection sampling (Geweke 1991). Since the multivariate normal can be sampled without MCMC methods, we do not need to worry about post-hoc rejection breaking a Markov chain. It is not terribly efficient, however: with, say, 100 points, each draw is from a 100-dimensional multivariate normal, and we must reject the entire draw if even a single one of the 100 points falls outside the bounds. Depending on the bounds, we may need to sample MANY times to get a sufficiently large sample from the constrained distribution2.

2 We will show this with a print statement in the code.
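A minimal version of this rejection step can be sketched as follows (standalone: the squared-exponential covariance stands in for a BART kernel at 100 inputs, and the lower bound of −2 and draw counts are illustrative). The print statement reports how many raw draws were needed.

```r
set.seed(42)
n_pts <- 100
# Stand-in for a BART kernel over 100 inputs (squared-exponential + jitter).
x <- seq(0, 1, length.out = n_pts)
K <- exp(-outer(x, x, function(a, b) (a - b)^2) / 0.1) + diag(1e-6, n_pts)
L <- chol(K)                               # K = t(L) %*% L
n_accept <- 0; n_tries <- 0
draws <- matrix(NA_real_, nrow = 0, ncol = n_pts)
while (n_accept < 10) {
  n_tries <- n_tries + 1
  f <- drop(crossprod(L, rnorm(n_pts)))    # one draw from N(0, K)
  if (all(f > -2)) {                       # keep only draws obeying the bound
    draws <- rbind(draws, f)
    n_accept <- n_accept + 1
  }
}
cat("Accepted", n_accept, "of", n_tries, "draws\n")
```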

There is a catch with this approach: it isn't BART anymore. We start from the multivariate normal implied by the BART kernel, but the rejection step discards samples that would be perfectly valid BART draws. The BART tree-growing machinery, so heralded for its accuracy, plays no role in which functions survive the rejection step. In other words, the functions we sample this way are not "BART with a constraint"; they are constrained functions that merely use the BART kernel as a building block.