13  Conclusion

13.1 Assignments

Choose 1 of the following assignments!

There were a lot of things we wish we got to but did not. Please prepare a lesson (some brief notes and a script, ideally with some data!) on one of these topics, including references:

  1. Hamiltonian Monte Carlo

  2. CAR (conditional autoregressive models) or Hawkes processes (for example to predict earthquakes)

  3. Survey data methods: (C. Krantsevich et al. 2023) is a really cool paper (available on arXiv) that brings together BART, copulas, AUC, and “fit the fit” type thinking.

  4. Alternatives to regression based causal inference: Another way to account for confounding is through methods that fall under the category of “matching”. Matching is a very intuitive approach that can be very loosely thought of as constructing an appropriate group to compare our observations to.

While so-called “matching” methods are popular, our aversion to them stems from both numerical concerns (choosing a proper match is a notoriously hard problem in higher dimensions) and philosophical ones (to create a proper match, one must either sacrifice the credibility of the matched pair or diminish power by limiting the sample size in the search for a proper match). Nonetheless, we provide a broad overview of the matching literature.

As previously mentioned, a common problem in the absence of randomized treatment assignment is that units that receive treatment and those that do not are systematically different in some way. The premise behind matching is to balance the treatment and control groups on their observed covariates. Note that we assume these covariates satisfy conditional unconfoundedness, an assumption maintained throughout this section.

Matching can be employed in numerous ways. In exact matching, groups are chosen such that the covariates match exactly, which is infeasible with more than a few covariates. Distance metrics are also often used to define similarity. These distances then define nearest neighbors, where the closest control unit to a treated unit is used to impute that treated unit's missing counterfactual. While these methods have appealing properties (exact matching in particular), neither is particularly useful when the dimensionality of the observed data gets large.
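Nearest neighbor matching can be sketched in a few lines of Python. This is a minimal illustration on simulated data, assuming Euclidean distance, 1-nearest-neighbor matching with replacement, and a made-up data-generating process with a true treatment effect of 1; the helper `nn_match_att` is hypothetical, not a library routine.

```python
import numpy as np

def nn_match_att(X, z, y):
    """ATT by 1-nearest-neighbor matching with replacement (Euclidean distance).

    Each treated unit's missing Y(0) is imputed with the outcome of the
    closest control unit in covariate space.
    """
    X_t, y_t = X[z == 1], y[z == 1]
    X_c, y_c = X[z == 0], y[z == 0]
    # squared distances between every treated and every control unit
    d2 = ((X_t[:, None, :] - X_c[None, :, :]) ** 2).sum(axis=2)
    nearest = d2.argmin(axis=1)  # index of the closest control for each treated unit
    return np.mean(y_t - y_c[nearest])

rng = np.random.default_rng(1)
n = 2000
X = rng.normal(size=(n, 2))
# treatment is more likely when x1 is large, so x1 confounds the comparison
z = rng.binomial(1, 1 / (1 + np.exp(-X[:, 0])))
y = 2 * X[:, 0] + X[:, 1] + 1.0 * z + rng.normal(scale=0.5, size=n)  # true effect = 1

naive = y[z == 1].mean() - y[z == 0].mean()
att = nn_match_att(X, z, y)
print(f"naive difference: {naive:.2f}, matched ATT: {att:.2f}")
```

The naive difference in means absorbs the confounding through x1, while the matched estimate lands near 1. Matching targets the ATT; it coincides with the ATE here only because the simulated effect is constant.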

As an extension of nearest neighbor matching, one can match on propensity scores (the conditional probability of receiving the treatment), selecting treatment-control pairs with a so-called “caliper”. The caliper (named after the measuring tool) is a common tuning parameter in propensity score matching, defined as the maximum distance two units can be apart to be considered potential matches. If the caliper is small, matched units are more similar. However, if few units fall within the caliper, there are fewer units to match and the effective sample size shrinks, which may force a wider caliper. A wider caliper, in turn, means the matched treated and control units are inherently less similar. These issues arise even with a well-estimated propensity score (or whatever metric is used to match on).
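The caliper trade-off is easy to see numerically. The hedged sketch below matches on a propensity score that is assumed known (in practice it would be estimated), uses 1-nearest-neighbor matching with replacement, and drops treated units with no control inside the caliper. The simulated data and the helper `caliper_match_att` are illustrative assumptions.

```python
import numpy as np

def caliper_match_att(ps, z, y, caliper):
    """1-NN matching on the propensity score, with replacement.

    Treated units whose closest control is farther than `caliper` are dropped.
    Returns (ATT estimate, number of treated units that found a match).
    """
    t_idx = np.flatnonzero(z == 1)
    c_idx = np.flatnonzero(z == 0)
    gaps = np.abs(ps[t_idx, None] - ps[None, c_idx])  # |ps_treated - ps_control|
    best = gaps.argmin(axis=1)
    keep = gaps[np.arange(len(t_idx)), best] <= caliper
    if not keep.any():
        return np.nan, 0
    diffs = y[t_idx[keep]] - y[c_idx[best[keep]]]
    return diffs.mean(), int(keep.sum())

rng = np.random.default_rng(2)
n = 1000
x = rng.normal(size=n)
ps = 1 / (1 + np.exp(-1.5 * x))  # true propensity score, assumed known here
z = rng.binomial(1, ps)
y = 2 * x + 1.0 * z + rng.normal(scale=0.5, size=n)

for caliper in [0.001, 0.01, 0.1, 0.5]:
    att, matched = caliper_match_att(ps, z, y, caliper)
    print(f"caliper={caliper:<6} matched treated={matched:4d}  ATT={att:.2f}")
```

Widening the caliper can only add matches (each treated unit's closest control does not change), so the printed counts grow with the caliper while the matched pairs become less similar.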

The propensity score, defined as \(\pi(\mathbf{x})=\Pr(Z=1\mid \mathbf{x})\) (where \(Z\) is the treatment variable), has played a central role in causal studies for many decades. (Rosenbaum and Rubin 1983) showed that it is sufficient to balance on the 1-dimensional propensity score, which serves as a beneficial dimension reduction. That is, conditioning on the propensity score ensures \(\mathbf{x} \perp Z \mid \pi(\mathbf{x})\). However, propensity score methods require estimating the propensity function, which is often non-trivial. Additionally, propensity scores base similarity entirely on the probability of treatment assignment, whereas information that is predictive of the outcome should also be considered, as in the regression adjustments we did in chapter 8; see (Senn, Graf, and Caputo 2007) for more details.
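The balancing property can be checked with a quick simulation. The sketch below is a hypothetical setup with the true propensity score assumed known: it compares the treated-minus-control difference in the mean of a covariate overall and within propensity score deciles. Stratifying on the score should shrink the imbalance, though not to exactly zero, since the score still varies within each decile.

```python
import numpy as np

rng = np.random.default_rng(8)
n = 20_000
x1, x2 = rng.normal(size=n), rng.normal(size=n)
ps = 1 / (1 + np.exp(-(x1 + 0.5 * x2)))  # true propensity score
z = rng.binomial(1, ps)

def mean_gap(mask):
    """Treated-minus-control difference in the mean of x1 among units in `mask`."""
    return x1[mask & (z == 1)].mean() - x1[mask & (z == 0)].mean()

overall = mean_gap(np.ones(n, dtype=bool))
edges = np.quantile(ps, np.linspace(0, 1, 11))  # propensity score deciles
strata = np.clip(np.searchsorted(edges, ps, side="right") - 1, 0, 9)
within = [mean_gap(strata == s) for s in range(10)]

print(f"x1 gap overall: {overall:.3f}")
print("x1 gap within deciles:", np.round(within, 3))
```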

A common estimator is the “inverse propensity weight” (IPW) estimator, which leverages observed data to estimate average treatment effects. This estimator is defined as:

\[ \overline{\tau}_{\text{IPW}} =\frac{1}{N}\sum_{i=1}^{N}\left(\frac{Y_iZ_i}{\pi(\mathbf{x}_i)}-\frac{Y_i(1-Z_i)}{1-\pi(\mathbf{x}_i)}\right) \]

Its expectation can be shown by direct calculation to equal the average treatment effect under the three assumptions of exchangeability, positivity, and SUTVA. In practice, \(\pi(\mathbf{x})\) is unknown and is often estimated from the data. A consistent estimate \(\hat{\pi}(\mathbf{x})\) gives a consistent estimator, but exact unbiasedness does not carry over automatically: even if \(\hat{\pi}(\mathbf{x})\) is unbiased for \(\pi(\mathbf{x})\), \(1/\hat{\pi}(\mathbf{x})\) is generally not unbiased for \(1/\pi(\mathbf{x})\).
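A minimal Python sketch of the IPW estimator, assuming the true propensity score is known (so the exact unbiasedness argument applies) and using a made-up data-generating process whose average treatment effect is 1. The helper `ipw_ate` is hypothetical.

```python
import numpy as np

def ipw_ate(y, z, ps):
    """Inverse propensity weighted estimate of the average treatment effect."""
    return np.mean(y * z / ps - y * (1 - z) / (1 - ps))

rng = np.random.default_rng(3)
n = 20_000
x = rng.normal(size=n)
ps = 1 / (1 + np.exp(-x))  # true propensity score, assumed known
z = rng.binomial(1, ps)
y = 1 + 2 * x + 1.0 * z + rng.normal(size=n)  # true ATE = 1

naive = y[z == 1].mean() - y[z == 0].mean()
print(f"naive difference: {naive:.2f}")
print(f"IPW estimate:     {ipw_ate(y, z, ps):.2f}")
```

Units with propensity scores near 0 or 1 receive very large weights, which is where the estimator's extra variance comes from.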

The issues with matching are that a) regression-based methods accomplish the same goal and are well developed, and b) there is a fundamental trade-off between creating better matches and throwing away more data. Creating a more similar comparison group means resigning yourself to fewer nearby neighbors, using less of your (likely limited) available data. Casting a wider net means you really aren’t comparing similar observations anymore. And this is on top of the difficulties that distance-based methods face in higher dimensions.

Using propensity score methods is a stronger alternative, but again, the propensity score is already incorporated in methods like BCF (P. Richard Hahn, Murray, and Carvalho 2020). Additionally, (Herren and Hahn 2020) show in simulation studies that the BCF estimator of the average treatment effect tended to outperform the IPW estimator1 in finite samples. Matching is useful pedagogically, both as an intuitive method and for visualizing the overlap assumption clearly (for example, by looking at the distribution of propensity scores).

The fundamental reason regression adjustment is preferable to propensity-based methods is that it accounts for relationships with the response variable. As stated in (Senn, Graf, and Caputo 2007) and (P. Richard Hahn and Herren 2022), including prognostic variables can reduce the variance of an estimator (make it more stable). A prognostic variable is not predictive of the treatment variable, so including it in a propensity model (\(\Pr(Z=1\mid \mathbf{x})\)) adds no value. Because propensity-based estimators ignore measured variables that affect only the outcome, they will likely have higher variance than regression-adjustment estimators. In regression, since we are predicting the outcome \(Y\) given confounders and potentially prognostic variables, we can expect lower-variance estimates of the treatment effect. Of course, including the propensity score in a regression adjustment was shown to be a really helpful tool for reducing the bias associated with regularization-induced confounding in machine learning (P. R. Hahn et al. 2016), (P. Richard Hahn, Murray, and Carvalho 2020), as explained in chapter 8.

However, given the issues we described above and the presence of a clear alternative with fewer holes (response/regression-based adjustments), we argue in favor of that alternative over matching and propensity-based methods for ATE estimation in observational studies.

  5. Quasi-experimental designs:

    There are certain scenarios where we surely cannot meet the criteria needed to use BCF. What do we do then? Well, desperate times call for desperate measures.

    1. Instrumental variable regression:

    2. The front door criterion

  6. Panel data approaches:

    a) Difference-in-differences designs:

    Made from a tutorial by PyMC. The arrow represents the “treatment effect”.

A natural companion to a panel data study is a standard difference-in-differences (DID) model. This is a popular approach in economics, popularized by (Card and Krueger 1994), who studied the effect of a minimum wage increase on fast-food employment in neighboring New Jersey and Pennsylvania. New Jersey had recently raised its minimum wage, whereas Pennsylvania had not, so using Pennsylvania as a control and assuming parallel trends (see the figure above) allows us to estimate the effect of the minimum wage increase in New Jersey.

The difference in difference estimand is defined, in the potential outcome framework, as:

\[ \text{ATT} = E(Y^1_{\text{post}}-Y^1_{\text{pre}}\mid Z=1)-E(Y^0_{\text{post}}-Y^0_{\text{pre}}\mid Z=1) \]

From the graph above, the intuition behind this estimand can be built. A common strategy to estimate this quantity is the two-way fixed effect estimator (commonly known as TWFE), which can be expressed as:

\[ Y_{it}=\alpha+\beta z_i+\gamma t+\delta(z_i\cdot t)+\varepsilon_{it}\]

where \(\alpha\) is the mean outcome of the control group in the pre-intervention period, \(\gamma\) is the change in the control group's mean between the pre and post periods, \(\beta\) is the difference between the treated and control groups in the pre-treatment period, and \(\delta\) is the average treatment effect on the treated. \(z_i\) indicates whether unit \(i\) receives the treatment (say, a policy), and the time indicator equals 1 in the post-intervention period. This regression is typically fit by ordinary least squares. Here is the code for the plot above, inspired by (and in part borrowed from) PyMC (not shown)1:

Full code (not run):
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import pymc as pm
#import seaborn as sns

RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
#az.style.use("arviz-darkgrid")

def outcome(t, control_intercept, treat_intercept_delta, trend, Δ, group, treated):
    return control_intercept + (treat_intercept_delta * group) + (t * trend) + (Δ * treated * group)
  
def is_treated(t, intervention_time, group):
    return (t > intervention_time) * group
  
# true parameters
control_intercept = 1.5
treat_intercept_delta = 0.4
trend = -1.25
Δ = 0.75
intervention_time = 0.5

fig, ax = plt.subplots()
ti = np.linspace(-0.5, 1.5, 1000)
# Create grid
# zorder sets the drawing layer: lines default to zorder=2, so zorder=1 keeps the grid behind the data.
ax.grid(which="major", axis='y', color='#1d2951', alpha=0.8, zorder=1)
ax.plot(
    ti,
    outcome(
        ti,
        control_intercept,
        treat_intercept_delta,
        trend,
        Δ=0,
        group=1,
        treated=is_treated(ti, intervention_time, group=1),
    ),
    color="#4c2e5b",
    label="counterfactual",
    ls=":",
)
ax.plot(
    ti,
    outcome(
        ti,
        control_intercept,
        treat_intercept_delta,
        trend,
        Δ,
        group=1,
        treated=is_treated(ti, intervention_time, group=1),
    ),
    color="#55AD89",
    label="treatment group",
)
ax.plot(
    ti,
    outcome(
        ti,
        control_intercept,
        treat_intercept_delta,
        trend,
        Δ,
        group=0,
        treated=is_treated(ti, intervention_time, group=0),
    ),
    color="#455079",
    label="control group",
)
ax.axvline(x=intervention_time, ls="-", color="#631f19", label="treatment time", lw=2)
t = np.array([0, 1])
ax.plot(
    t,
    outcome(
        t,
        control_intercept,
        treat_intercept_delta,
        trend,
        Δ,
        group=1,
        treated=is_treated(t, intervention_time, group=1),
    ),
    "o",
    color="#55AD89",
)
ax.plot(
    t,
    outcome(
        t,
        control_intercept,
        treat_intercept_delta,
        trend,
        Δ=0,
        group=0,
        treated=is_treated(t, intervention_time, group=0),
    ),
    "o",
    color="#455079",
)
ax.set(
    xlabel="time",
    ylabel="metric",
    xticks=t,
    xticklabels=["pre", "post"],
    title="Difference in Differences",
)
# Remove spines. Can be done one at a time or can slice with a list.
#ax.spines[['top','right','left']].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(False)

# Shrink y-lim to make plot a bit tighter
#ax.set_ylim(0, 23)

# Set xlim to fit data without going over plot area
#ax.set_xlim(0,0.5)

# Reformat x-axis tick labels
ax.xaxis.set_tick_params(labelsize=11)        # Set tick label size

# Reformat y-axis tick labels
#ax.set_yticklabels(np.arange(0,25,5),            # Set labels again
#                   ha = 'right',                 # Set horizontal alignment to right
#                   verticalalignment='bottom')   # Set vertical alignment to make labels on top of gridline      

ax.yaxis.set_tick_params(pad=-2,             # Pad tick labels so they don't go over y-axis
                         labeltop=True,      # On the y-axis this maps to label2On: show labels on the right
                         labelbottom=False,  # Maps to label1On: hide labels on the left
                         bottom=False,       # Maps to tick1On: hide ticks on the left
                         labelsize=11)       # Set tick label size
ax.legend();
# Add in title and subtitle
# Add in line and tag
ax.plot([0.12, .9],                  # Set width of line
        [.98, .98],                  # Set height of line
        transform=fig.transFigure,   # Set location relative to plot
        clip_on=False, 
        color='#1d2951', 
        linewidth=.6)
ax.add_patch(plt.Rectangle((0.12,.98),                 # Set location of rectangle by lower left corner
                           0.04,                       # Width of rectangle
                           -0.02,                      # Height of rectangle. Negative so it goes down.
                           facecolor='#1d2951', 
                           transform=fig.transFigure, 
                           clip_on=False, 
                           linewidth = 0))
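The two-period DID regression can also be fit directly. The sketch below is a minimal illustration, not the PyMC tutorial's approach: it simulates a balanced two-period panel with the same hypothetical parameter values as the plot, regresses the outcome on an intercept, the group indicator \(z\), a post-period indicator \(t\), and their interaction by ordinary least squares with NumPy, and checks that the interaction coefficient equals the difference of the four group means.

```python
import numpy as np

rng = np.random.default_rng(4)
n_units = 500
# hypothetical "true" parameters, matching the values used in the plot
alpha, beta, gamma, delta = 1.5, 0.4, -1.25, 0.75

group = rng.binomial(1, 0.5, size=n_units)
z = np.repeat(group, 2)       # treatment-group indicator, fixed within a unit
t = np.tile([0, 1], n_units)  # 0 = pre period, 1 = post period
y = alpha + beta * z + gamma * t + delta * z * t + rng.normal(scale=0.3, size=2 * n_units)

# OLS on an intercept, z, t, and the z*t interaction
design = np.column_stack([np.ones(2 * n_units), z, t, z * t])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)

def cell_mean(g, period):
    """Mean outcome for group g in the given period."""
    return y[(z == g) & (t == period)].mean()

did = (cell_mean(1, 1) - cell_mean(1, 0)) - (cell_mean(0, 1) - cell_mean(0, 0))
print(f"OLS delta: {coef[3]:.3f}   difference of four means: {did:.3f}")
```

Because this regression is saturated (one free parameter per group-period cell), the OLS interaction estimate reproduces the difference of differences of the cell means exactly.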
  b) Synthetic control methods: The synthetic control method (Abadie and Gardeazabal 2003) assumes you can construct a “synthetic” version of your treated unit’s outcomes using a weighted combination of the non-treated units’ outcomes before the intervention. The picture below is a study of the “synthetic controls” of Colorado, used to assess the impact of marijuana legalization on GDP.

    Figure 13.1

While this seems intuitive, the question of why this is causally valid remains. Identification basically hinges on “if you can find good weights, the weights are good” (Rafael Alcantara). The script below shows a scenario where we would likely find a “causal effect” simply due to spurious correlations, replicating a “random walk” DGP from the simulation mini chapter, as illustrated in Figure 13.2.

    set.seed(12024)

    N = 60
    n_unit = 50
    RWs <- ts(replicate(n = n_unit, 
                        arima.sim(model = list(order = c(0, 1 ,0)), n = N-1, mean=0.25)))
    RWs = as.data.frame(RWs)
    RWs_2 = as.data.frame(unlist(lapply(1:n_unit, function(j) RWs[, j])))
    colnames(RWs_2) = 'vals'
    #vals = c(RWs[,1],RWs[,2],RWs[,3],RWs[,4],RWs[,5],
    #         RWs[,6],RWs[,7],RWs[,8],RWs[,9],RWs[,10])

    RW_df = data.frame(value = RWs_2$vals, 
                       group_id = unlist(lapply(1:n_unit, function(j) rep(j, N))), 
                       time = rep(seq(from=1, to=N),n_unit), 
                       treat_indicator = c(rep(0,N/2),rep(1,N/2), rep(0, n_unit*N - N)))
    RW_df = as.data.frame(RW_df)

    highest_cor = order(sapply(1:n_unit, function(k)sort(cor(RWs[1:21,])[,k])[n_unit-1]))[n_unit]

    matplot(RWs[, c(highest_cor, 
                    order(cor(RWs)[,highest_cor])[n_unit-1], 
               order(cor(RWs)[,highest_cor])[n_unit-2]) ], 
            #lty = 1,
            lwd = 3.25,
            type = "l",
            col = c( "#073d6d", '#d47c17', "#55AD89"),
            lty = c(1,2,4),
            bty='l',
            xlab = "Time",
            ylab = "",
            main = "Random walk with drift realizations")  
    abline(v=22, lwd=2,col="black")
    legend('bottomright',
           c('Treated unit', #paste0('Unit: ',highest_cor), 
             'Most correlated', #paste0('Unit: ',order(cor(RWs)[,highest_cor])[n_unit-1]), 
             'Second most correlated'),#paste0('Unit: ',order(cor(RWs)[,highest_cor])[n_unit-2])),
           lty=c(1,2,4),
                    col = c('#073d6d','#d47c17', '#55AD89'), 
                    lwd=2,bty = "n", title='')
Figure 13.2: Coincidentally similar-looking trajectories would doom an SCM analysis, as this situation is indistinguishable from one where the curves are actually generated as “twins” differing only in the treatment effect. On the other hand, in a regression adjustment, the assumptions are much easier to argue about. Accidentally mislabeling confounders (as mediators, colliders, etc.) or omitting them altogether (through neglect or limited data collection, for example) would be far easier for a recipient/customer of an analysis to scrutinize.
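To make the weight-finding step concrete, here is a hedged Python sketch of a bare-bones synthetic control. It finds nonnegative weights summing to 1 that best reproduce the treated unit's pre-intervention outcomes (simplex-constrained least squares, solved by projected gradient descent), then reports the post-intervention gap. It omits the covariates and predictor-weighting matrix used in the full method. As in the R script, it is applied to independent random walks with drift, where the true effect is zero, so any post-intervention gap is spurious. The function names are hypothetical.

```python
import numpy as np

def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1} (sort-based)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    k = np.arange(1, v.size + 1)
    rho = np.nonzero(u * k > css - 1)[0][-1]
    theta = (css[rho] - 1) / (rho + 1)
    return np.maximum(v - theta, 0.0)

def synth_weights(Y0_pre, y1_pre, n_iter=2000):
    """Minimize ||y1_pre - Y0_pre @ w||^2 over the simplex by projected gradient."""
    w = np.full(Y0_pre.shape[1], 1.0 / Y0_pre.shape[1])  # start from equal weights
    step = 1.0 / (2 * np.linalg.norm(Y0_pre, 2) ** 2)    # 1 / Lipschitz constant of the gradient
    for _ in range(n_iter):
        grad = 2 * Y0_pre.T @ (Y0_pre @ w - y1_pre)
        w = project_to_simplex(w - step * grad)
    return w

rng = np.random.default_rng(12024)
T, T0, n_units = 60, 30, 50
# independent random walks with drift 0.25; column 0 plays the "treated" unit
paths = np.cumsum(rng.normal(loc=0.25, size=(T, n_units)), axis=0)
y1, Y0 = paths[:, 0], paths[:, 1:]
w = synth_weights(Y0[:T0], y1[:T0])
gap = y1 - Y0 @ w  # the true treatment effect is 0 at every time point
print(f"pre-period RMSE: {np.sqrt(np.mean(gap[:T0] ** 2)):.2f}")
print(f"mean post-period gap (a spurious 'effect'): {gap[T0:].mean():.2f}")
```

A small pre-period RMSE is exactly the diagnostic an SCM analysis relies on, yet here it coexists with a post-period gap that no treatment caused.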

Some other issues are similar to those we have with difference-in-differences. What if there is a shock after the intervention (like the pandemic in 2020)? How long after the intervention do we care about the treatment? What about “spillover” effects on Colorado's GDP from other states’ residents coming to Colorado? Figure 13.1 already shows that the fitted weights do not seem particularly logical, as a sort of ad hoc preview of the criticism to follow.

In causal inference lingo, the prognostic effect is the unobserved “synthetic control”, and the treatment is observed. In other words, the confounding is captured through the “latent” variable that determines both treatment status and the outcome. Since the synthetic control is the prognostic factor, it affects both a state’s desire to adopt a policy and its baseline outcome in the absence of one.

Logically, assuming that all your confounding arises from some latent factor is iffy at best. States that implement a policy do so because of some underlying factors that also impact the outcome. Banning smoking in restaurants may happen because smoking in restaurants was becoming unpopular! But maybe this is politically oriented, so that some states were already cooling on cigarettes and more likely to adopt the policy, whereas other states had no desire to do so. Overlap is thus a potential issue. However, the issue is more severe than debating whether or not the assumption makes sense in the abstract. Unfortunately, there is no real way to verify the assumption that the treated unit’s “control potential outcome” is a weighted sum of the control units’ outcomes, because the assumption hinges on values we cannot observe (namely, that the treated unit's outcomes truly, and not just coincidentally, are a weighted sum of the control units' outcomes, so that the weights fit before the intervention would keep holding afterwards in the absence of treatment). Even if we drop the weighted-sum assumption and merely assume the time series all arose from a factor model (with \(k\) factors determining \(k\) different types of time series), we still rely on the assumption that the time series curves are similar because some underlying latent factor we cannot observe generated them. See the blog post by Kyle Butts for a nice connection between the factor model formulation and synthetic controls (Butts, n.d.).

The crux of the issue can be summarized as follows: if the latent factors don’t exist, you can still find them (random walks can be correlated, for example). Even if they do exist, it’s hard to know whether you found the correct ones. So the assumption that common (unobserved) factors generated the different time series we observe is simply not verifiable! Such is the game of relying on unobservable variables in your assumptions2. (Zeitler, Vlontzos, and Gilligan-Lee 2023) illustrate all these points succinctly using the causal diagram framework. Figure 13.3, a screengrab from the paper, highlights this point.

Figure 13.3: Visual exploration of confounding arising from a latent variable.

FWIW, Discord (yes, the app) has a cool application of synthetic controls. (Ben-Michael et al. 2023) have a cool paper showing the connection between synthetic controls and Gaussian processes.

So this method, like some of those in chapter 10, falls into the “well yes, mathematically that’s true, but not necessarily practically true” category. The “practically” here refers to the fact that we can concoct situations where the treated state’s outcomes just so happen to lie in the convex hull of other states’ outcomes (i.e. nonnegative weights that sum to 1 exist). But treating the resulting post-intervention gap as an effect implies any changes in the post-intervention period are causal and not circumstantial, a big assumption. Further, while the idea is clever, digging further into this assumption reveals it is stronger than usually presented (a theme we’ve become friendly with).

This reminds us of the unsupervised learning section: while the logic may be difficult to square, there are use cases. There are clusterings/dimension reductions of data that seem reasonable, and we showed diagnostics and visuals to help us have faith in them. However, the clusterings or dimension reductions give outputs that are just the most “mathematically convenient” and require a leap of faith by the researcher to be useful beyond that. In this sense, synthetic control methods are very similar.

c) Regression discontinuity designs

    options(warn=-1)
    suppressMessages(library(tidyverse))
    suppressMessages(library(dplyr))
    df_alcohol = read.csv(paste0(here::here(), '/data/df_alcohol_mort.csv'))
    df1 <- df_alcohol %>%
      select(agecell, all, allfitted, internal, internalfitted, external, externalfitted) %>%
      gather(key = "variable", value = "value", -agecell)
    cut_val = 21
    df_alcohol$cut_off <- ifelse(df_alcohol$agecell>=cut_val, 'Over 21','Below 21')
    all_outcome <- df_alcohol%>%
      ggplot(aes(x=agecell, y= all,color=cut_off))+geom_point(size=1.75)+
      geom_smooth(data=df_alcohol[df_alcohol$agecell >= cut_val-1 &
                                     df_alcohol$agecell<=cut_val+1,
      ], method='lm',formula=y~x,se=T, lwd=1.5)+
      xlab('Age')+ylab('Overall Mortality Rates')+

    theme_minimal()+
      theme(plot.title=element_text(hjust=0.5, size=12), legend.position="bottom")+
      scale_color_manual(values=c( '#55AD89', 
                                   '#1d2951'), name='')

    alcohol_outcome <- df_alcohol%>%
      ggplot(aes(x=agecell, y= alcohol,color=cut_off))+geom_point(size=1.75)+
      geom_smooth(data=df_alcohol[df_alcohol$agecell >= cut_val-1 &
                                    df_alcohol$agecell<=cut_val+1,
      ], method='lm',formula=y~x,se=T,lwd=1.5)+
      xlab('Age')+ylab('Alcohol Mortality Rates')+

    theme_minimal()+
      theme(plot.title=element_text(hjust=0.5, size=12), legend.position="bottom")+
      #scale_linetype_manual(values=c('solid', 'dashed', 'solid'), name='')+
      scale_color_manual(values=c( '#55AD89', 
                                   '#1d2951'), name='')
    alcohol_outcome
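The plot fits separate straight lines within one year of age on either side of the cutoff at 21; the RDD estimate is the jump between those two lines at the cutoff. Below is a minimal Python sketch of that local linear estimator on simulated data (not the alcohol mortality data), assuming a made-up jump of 8 and a bandwidth of one year. The helper `rdd_local_linear` is hypothetical.

```python
import numpy as np

def rdd_local_linear(x, y, cutoff, bandwidth):
    """Fit separate OLS lines within `bandwidth` on each side of the cutoff
    and return the difference of their predictions at the cutoff."""
    def intercept_at_cutoff(mask):
        X = np.column_stack([np.ones(mask.sum()), x[mask] - cutoff])  # center at cutoff
        b, *_ = np.linalg.lstsq(X, y[mask], rcond=None)
        return b[0]  # fitted value at x = cutoff
    left = (x >= cutoff - bandwidth) & (x < cutoff)
    right = (x >= cutoff) & (x <= cutoff + bandwidth)
    return intercept_at_cutoff(right) - intercept_at_cutoff(left)

rng = np.random.default_rng(5)
age = rng.uniform(19, 23, size=2000)
jump = 8.0  # hypothetical discontinuity at age 21
y = 90 + 1.5 * (age - 21) + jump * (age >= 21) + rng.normal(scale=3, size=age.size)
print(f"estimated jump at 21: {rdd_local_linear(age, y, 21, 1.0):.2f}")
```

The bandwidth plays the same role as the tree-width calibration discussed below: narrower windows reduce bias from curvature but leave fewer points to fit each line.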

A BART-based RDD, an inspired merging of a clever idea in causal inference with the power of the BART prior, was presented in (Alcantara et al. 2024) and can be read on arXiv. The authors modify the BART model to ensure tree splits around the cutoff are populated (on both sides) while making sure that the tree nodes do not encompass too wide a region around the cutoff. See the figure below.

The width is calibrated using prior predictive studies, which is quite clever.

(Alcantara et al. 2025) improve upon this result: instead of modifying the tree structure, they use leaf regression on each side of the cutoff to determine the causal effect. This is a more straightforward approach that is also much easier to calibrate and test in simulation studies.

  d) BART for panel data: (Wang, Martinez, and Hahn 2024) is a BART-based approach to panel data and can also be found on arXiv. This model relaxes the assumption of “parallel trends”, which is quite a strong one. It also allows for estimation of conditional average treatment effects (meaning treatment effect heterogeneity can be measured). The model is the BCF model with a few extensions: time is included as a covariate in both the prognostic and treatment trees, the treatment trees also take the time since treatment as an input, and the treatment term is multiplied by a Gaussian process that takes time as an input and models the temporal trend of the treatment effect (shared across units, as a function of time since treatment).
  7. Graph/network analyses:

(McCartan and Imai 2023) use sequential Monte Carlo to sample redistricting plans. flexBART (Deshpande 2024) is designed as an alternative to one-hot encoding categorical variables in BART models, since one-hot encoding discards any structure between the categories (for instance, when the categories are in fact ordered). For categorical variables, the model “chooses” which categorical variable to split on uniformly at random and then builds the decision rule from a random subset of the available levels3. For categorical variables whose levels have a graph structure (e.g. neighboring regions), flexBART uses spanning trees to encode the dependence between levels. Graph/network-based distances could be useful for road data (on a map), for example. On a somewhat unrelated note, (Dahl et al. 2024) modified BART to study road risk in Utah. It is a very interesting application and modification of BART. Although they did not use a network-type analysis, the data at hand suggest it could be worthwhile.

  8. Social networks and causation: (Shalizi and Thomas 2011) is a super cool paper to read. Homophily refers to the tendency of similar people to group together, while contagion says people draw their ideas and personalities from those around them. This paper argues that the two are very difficult to pull apart in observational data. Are people living in an echo chamber (a self-fulfilling prophecy), or are they brainwashed?

1 The estimator has a term in the denominator which can make it a higher variance estimator.

2 Of course, the assumption we saw in chapter 8 that we observe all confounding isn’t verifiable either. BUT, it seems easier to convince yourself whether the variables you include are confounders, and whether there are variables you’d consider confounders if you had the resources to go and measure them. Even if you do not have them, you can clarify for future studies which variables to go and find in the data curation stage of that study.

3 For example, if the decision options are red, blue, or green, a random subset of those is selected, like blue and green.


Make a slide deck on common statistical fallacies, biases, and generally annoying things that can derail an analysis: what they are, why they are problematic, and how to avoid them. Include information about fallacies/issues from chapter 1 (and more if you have encountered others!). Provide some ways to counter these fallacies and describe what you would do.


  • Attached is data where the \(\mathbf{X}_t\) (indexed over different times) and \(\mathbf{X}_k\) (over \(k\) additional covariates) are generated from separate factor models and the \(y_t\) are generated as a function of each \(\mathbf{X}\). The data are meant to be a simulated result of player performance in your favorite sports league over time. Build a projection system (predict \(y_t\) into the future given \(X\) and \(t\) values). Do so as follows:

    1. Use a k-nearest neighbors approach where the \(y_t\) are predicted using the nearest neighbors. This is what Nate Silver did with PECOTA, one of the first and best baseball projection systems.

    2. Use BART and predict based on the previous data points and other covariates. BART does not really extrapolate well, so keep that in mind. Maybe try (Wang, He, and Hahn 2024) that we will discuss later.

    3. Use an ARIMA model and a prophet model to predict the time series out into the future. Play around with parameters a little, but no need to do much here.
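For part 1, a comparable-player projection can be sketched as follows. This is a hypothetical illustration, not PECOTA's actual method and not tied to the attached data's layout: each player is described by a few standardized features (here age and the last two seasons' outcomes, all simulated), and the next-season outcome is predicted as the average next-season outcome of the k most similar historical players.

```python
import numpy as np

def knn_project(feat_hist, next_hist, feat_new, k=10):
    """Average next-season outcome of the k nearest historical players."""
    mu, sd = feat_hist.mean(axis=0), feat_hist.std(axis=0)
    A = (feat_hist - mu) / sd  # standardize so no feature dominates the distance
    B = (feat_new - mu) / sd
    d2 = ((B[:, None, :] - A[None, :, :]) ** 2).sum(axis=2)
    nearest = np.argsort(d2, axis=1)[:, :k]
    return next_hist[nearest].mean(axis=1)

rng = np.random.default_rng(7)

def simulate(n):
    """Hypothetical players: latent skill plus noise, with a mild aging curve."""
    age = rng.uniform(21, 38, size=n)
    skill = rng.normal(size=n)
    y_prev = skill + rng.normal(scale=0.5, size=n)
    y_now = skill + rng.normal(scale=0.5, size=n)
    y_next = skill - 0.01 * (age - 29) ** 2 + rng.normal(scale=0.5, size=n)
    return np.column_stack([age, y_now, y_prev]), y_next

feat_hist, next_hist = simulate(2000)
feat_new, next_true = simulate(200)
pred = knn_project(feat_hist, next_hist, feat_new, k=10)
print(f"kNN MSE: {np.mean((pred - next_true) ** 2):.3f}")
print(f"MSE of predicting the historical mean: {np.mean((next_hist.mean() - next_true) ** 2):.3f}")
```

The choice of features and their scaling defines who counts as "comparable", so it is worth experimenting with both before comparing against BART, ARIMA, or prophet.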



In general, we did not talk much about violations of regression assumptions (like heteroskedasticity, correlated errors, non-normal errors, etc.). Below is a list of papers that relax those assumptions for BART (yes, we skipped the step of introducing solutions with linear models, but oh well!), as well as other special-topic BART (and some Gaussian process) papers. Read 1-2 of them and prepare as if you were giving a short lecture on them!

Original BART paper

Review of BART and extensions:

As we have seen, correlated errors can be problematic. When a linear regression is fit by OLS, the estimated variances of the parameter estimates can be wrong. When using BART, estimates of \(E(y\mid \mathbf{X})\) can be badly biased (with poor coverage as well) under extreme error correlation.
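A small simulation makes the OLS point concrete. Assuming, for illustration, that both the regressor and the errors follow AR(1) processes with autocorrelation 0.9, the slope estimate stays centered on its true value, but the standard error OLS reports (which assumes independent errors) is much smaller than the actual sampling variability across replications. The helper `ar1` and all parameter values are hypothetical.

```python
import numpy as np

def ar1(n, phi, rng):
    """Simulate a stationary AR(1) series with unit-variance innovations."""
    e = np.empty(n)
    e[0] = rng.normal() / np.sqrt(1 - phi**2)  # draw from the stationary distribution
    for t in range(1, n):
        e[t] = phi * e[t - 1] + rng.normal()
    return e

rng = np.random.default_rng(6)
n, phi, reps = 200, 0.9, 500
slopes, naive_ses = [], []
for _ in range(reps):
    x = ar1(n, phi, rng)
    y = 1.0 + 0.5 * x + ar1(n, phi, rng)  # true slope 0.5, autocorrelated errors
    X = np.column_stack([np.ones(n), x])
    b, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ b
    s2 = resid @ resid / (n - 2)
    naive_ses.append(np.sqrt(s2 * np.linalg.inv(X.T @ X)[1, 1]))  # textbook OLS SE
    slopes.append(b[1])

print(f"mean slope estimate:     {np.mean(slopes):.3f}")
print(f"empirical SD of slope:   {np.std(slopes):.3f}")
print(f"average OLS-reported SE: {np.mean(naive_ses):.3f}")
```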

Gaussian processes

Correlated Bayesian Additive Regression Trees: (Lu and McCulloch 2023) approach the problem of correlated errors by designing a bespoke BART prior that accounts for correlation in the errors. They also add a Gaussian process term to their regression to account for spatial or temporal errors, i.e.

\[ y(\mathbf{x}_i) = \text{special BART}(\mathbf{x}_i)+\text{GP}(\mathbf{x}_i) \]

They describe a procedure to ensure the BART term and the GP term remain identified when both are estimated simultaneously. The GP can have a specialized covariance for the problem at hand, such as if we assume the errors are correlated temporally, like in a time series.

  • Varying coefficient BART: Bayesian trees for varying coefficients (VCBART) were studied in (Deshpande et al. 2020). This places a separate BART prior on each coefficient in the model, which essentially means we take a linear combination of BART priors. The VCBART model is a sum of BART forests, each multiplied by a covariate, i.e. \(y = \beta_{0}+\beta_{1}(\mathbf{x})z_1+\beta_{2}(\mathbf{x})z_2+\ldots+\varepsilon\). A special case of this type of model is Bayesian Causal Forests (BCF), with just an intercept forest and a forest acting on the treatment variable \(Z\). The \(z\)’s above do not have to be treatment variables (just to avoid confusion). In general, multivariate leaf regression will accomplish a similar goal to VCBART but with an easier implementation in stochtree. Rather than placing a separate forest on each coefficient, we can perform a basis expansion in the leaves, which can accomplish similar “smoothing” objectives. Depending on the basis we use for the regression, we can study outcomes similar to those we could with a VC model. See the Physics informed BART section or the multiple outcomes discussion in chapter 11.

    VCBART is a useful model if the coefficient functions themselves warrant analysis (à la a BCF model). If the problem calls for additive components that each need their own forest (with different input variables and parameters), this model is great. If the goal is merely prediction, you’d likely be better served by leaf regressions. Below is an R implementation:

Code to run an additive forest model:
options(warn=-1)
suppressMessages(library(stochtree))
suppressMessages(library(tidyverse))
suppressMessages(library(dplyr))
suppressMessages(library(gridExtra))
suppressMessages(library(ggthemes))
# sample size
n = 1000
x1 <- runif(n, 0,2)
x2 <- runif(n,0,2)
x3 <- runif(n,0,2)

z1 <- runif(n, 0,2)
z2 <- runif(n,0,2)
z3 <- runif(n,0,2)
sig = 0.25
coef = 1#0.5
coef2 = -1
# model is beta_1(x)*z1+beta_2(x)*z2+beta_3*z3
y = sin(4*x1)*z1 +  coef*abs(x2-0.5)*z2  + coef2*x3*z3 + sig*rnorm(n)

# Sampling composition
num_warmstart <- 50
num_burnin <- 0
num_mcmc <- 500
num_samples <- num_warmstart + num_burnin + num_mcmc

# Sigma^2 samples
global_var_samples <- rep(0, num_samples)



# Bart hyperparameters
alpha_bart <- 0.95 # BART splitting probability
beta_bart <- 2 # BART depth penalty factor

min_samples_leaf <- 5
max_depth <- 20

num_trees <- 100
cutpoint_grid_size = 100
global_variance_init = 1.
tau_init = 1/num_trees
# Leaf prior scale: a diagonal matrix with one entry per leaf-regression dimension (here the 3 columns of X)
leaf_prior_scale = tau_init*diag(3)
#matrix(c(tau_init), ncol = 1)
nu <- 4
lambda <- 0.5
a_leaf <- 2.
b_leaf <- 0.5
leaf_regression <- T # fit a constant leaf mean BART model
# Standardize outcome
y_bar <- mean(y)
y_std <- sd(y)
resid <- (y-y_bar)/y_std
#outcome <- createOutcome(resid)
outcome <- createOutcome(y)
# Data
X1 <- as.matrix(x1)
X2 <- as.matrix(x2)
X3 <- as.matrix(x3)
Q1 = z1
Q2 = z2
Q3 = z3

X = cbind(X1,X2,X3)
p_X = ncol(X)

p_X1 = ncol(X1)
feature_types1 <- as.integer(rep(0, p_X)) # 0 = numeric
var_weights1 <- rep(1/p_X, p_X)

p_X2 = ncol(X2)
feature_types2 <- as.integer(rep(0, p_X)) # 0 = numeric
var_weights2 <- rep(1/p_X, p_X)


p_X3 = ncol(X3)
feature_types3 <- as.integer(rep(0, p_X)) # 0 = numeric
var_weights3 <- rep(1/p_X, p_X)

custom_func <- function(y, Q1, Q2,Q3, X,
                        num_trees, n, alpha, beta,
                        min_samples_leaf, max_depth){



}


# Common parameters
# use leaf regression for vc model
leaf_dimension1 <- p_X
leaf_dimension2 <- p_X
leaf_dimension3 <- p_X
# use leaf regression for vc model to update bases
outcome_model_type <- 1
nu <- 3
sigma2hat <- 1*(sigma(lm(y~as.matrix(X1))))^2
quantile_cutoff <- 0.9
if (is.null(lambda)) {
  lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu
}
sigma2 <- sigma2hat
current_sigma2 <- 1#sigma2



feature_types <- c(rep(0,p_X))
feature_types1 <- as.integer(c(feature_types))
feature_types2 <- as.integer(feature_types)
feature_types3 <- as.integer(feature_types)

forest_dataset1 <- createForestDataset(X, Q1)
forest_dataset2 <- createForestDataset(X, Q2)
forest_dataset3 <- createForestDataset(X, Q3)

# Random number generator (std::mt19937)
rng <- createCppRNG(012024)


# Sampling data structures
# For now, each forest is built with the same priors and
# tree settings and splits on the same input variables;
# in general, each forest could use its own inputs

forest_model_config1 <- createForestModelConfig(
  feature_types = feature_types1, num_trees = num_trees,
  num_features = p_X,
  num_observations = n, variable_weights = var_weights1, leaf_dimension = leaf_dimension1,
  alpha = alpha_bart, beta = beta_bart, min_samples_leaf = min_samples_leaf, max_depth = max_depth,
  leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale,
  cutpoint_grid_size = cutpoint_grid_size
)
forest_model_config2 <- createForestModelConfig(
  feature_types = feature_types2, num_trees = num_trees,
  num_features = p_X,
  num_observations = n, variable_weights = var_weights2, leaf_dimension = leaf_dimension2,
  alpha = alpha_bart, beta = beta_bart, min_samples_leaf = min_samples_leaf, max_depth = max_depth,
  leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale,
  cutpoint_grid_size = cutpoint_grid_size
)
forest_model_config3 <- createForestModelConfig(
  feature_types = feature_types3, num_trees = num_trees,
  num_features = p_X,
  num_observations = n, variable_weights = var_weights3, leaf_dimension = leaf_dimension3,
  alpha = alpha_bart, beta = beta_bart, min_samples_leaf = min_samples_leaf, max_depth = max_depth,
  leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale,
  cutpoint_grid_size = cutpoint_grid_size
)
global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init)
forest_model1 <- createForestModel(forest_dataset1, forest_model_config1, global_model_config)
forest_model2 <- createForestModel(forest_dataset2, forest_model_config2, global_model_config)
forest_model3 <- createForestModel(forest_dataset3, forest_model_config3, global_model_config)

# "Active forest" (which gets updated by the sample) and
# container of forest samples (which is written to when
# a sample is not discarded due to burn-in / thinning)
if (leaf_regression) {
  forest_samples1 <- createForestSamples(num_trees, 1, F)
  active_forest1 <- createForest(num_trees, 1, F)
  forest_samples2 <- createForestSamples(num_trees, 1, F)
  active_forest2 <- createForest(num_trees, 1, F)
  forest_samples3 <- createForestSamples(num_trees, 1, F)
  active_forest3 <- createForest(num_trees, 1, F)
} else {
  forest_samples1 <- createForestSamples(num_trees, 1, T)
  active_forest1 <- createForest(num_trees, 1, T)
  forest_samples2 <- createForestSamples(num_trees, 1, T)
  active_forest2 <- createForest(num_trees, 1, T)
  forest_samples3 <- createForestSamples(num_trees, 1, T)
  active_forest3 <- createForest(num_trees, 1, T)
}




# Initialize the leaves of each tree in the first term
active_forest1$prepare_for_sampler(forest_dataset1,
                                   outcome, forest_model1, 1,
                                   0)
active_forest1$adjust_residual(forest_dataset1,
                               outcome, forest_model1, T, F)

# Initialize the leaves of each tree in the second term
active_forest2$prepare_for_sampler(forest_dataset2,
                                   outcome, forest_model2,
                                   1, 0)
active_forest2$adjust_residual(forest_dataset2, outcome,
                               forest_model2, T, F)

# Initialize the leaves of each tree in the third term
active_forest3$prepare_for_sampler(forest_dataset3,
                                   outcome, forest_model3,
                                   1, 0)
active_forest3$adjust_residual(forest_dataset3, outcome,
                               forest_model3, T, F)




gfr_flag = T
sig_save_bart <- rep(NA, num_samples)
for (i in 1:num_samples) {

  # Use grow-from-root (GFR) updates during the warm start, then switch
  # over to random walk Metropolis-Hastings tree updates
  if (i > num_warmstart){
    gfr_flag <- F}

  # Sample the three BART forests in turn: each forest is updated against
  # the residual left over by the other two (backfitting), and
  # sample_one_iteration updates that shared residual in place. The bases
  # (Q1, Q2, Q3) never change, so no basis update is needed.
  forest_model1$sample_one_iteration(
    forest_dataset1, outcome, forest_samples1, active_forest1,
    rng, forest_model_config1, global_model_config,
    keep_forest = T, gfr = gfr_flag
  )
  forest_model2$sample_one_iteration(
    forest_dataset2, outcome, forest_samples2, active_forest2,
    rng, forest_model_config2, global_model_config,
    keep_forest = T, gfr = gfr_flag
  )
  forest_model3$sample_one_iteration(
    forest_dataset3, outcome, forest_samples3, active_forest3,
    rng, forest_model_config3, global_model_config,
    keep_forest = T, gfr = gfr_flag
  )

  # Update the error variance once per iteration (all forests share one
  # residual) and pass the new draw to the global model config so the
  # next round of forest updates conditions on it
  sigma2 <- sampleGlobalErrorVarianceOneIteration(
    outcome, forest_dataset1, rng, nu, lambda
  )
  global_model_config$update_global_error_variance(sigma2)
  sig_save_bart[i] <- sqrt(sigma2)
}
#hist(sig_save_bart, 50)

# Keep only the post-warm-start (MCMC) draws
keep <- (num_warmstart+num_burnin+1):num_samples
preds1 <- forest_samples1$predict_raw(forest_dataset1)[, keep]
preds2 <- forest_samples2$predict_raw(forest_dataset2)[, keep]
preds3 <- forest_samples3$predict_raw(forest_dataset3)[, keep]

preds1_mean <- rowMeans(preds1)
preds2_mean <- rowMeans(preds2)
preds3_mean <- rowMeans(preds3)

# Posterior draws of the mean function: predict_raw returns the leaf
# values beta_k(x), which are then multiplied by their bases z_k
yhat = preds1*Q1+preds2*Q2+preds3*Q3
# Add an independent N(0, sigma^2) error to every observation in each
# draw to get posterior predictive draws of y
yhat = sapply(1:ncol(yhat), function(i)
  yhat[,i] + rnorm(n, 0, sig_save_bart[keep][i]))

yhat_mean = preds1_mean*Q1+preds2_mean*Q2+preds3_mean*Q3
qm_y = apply(yhat, 1, quantile, probs=c(0.025, 0.975))
Lab <- expression(y == paste(sin(4*'x'[1])*'*z'[1]+
                                  abs('x'[2]-0.5)*'*z'[2]-
                                  'x'[3]*'*z'[3]+
                                  N(0,sigma^2)))
plot1 = data.frame(y, preds=yhat_mean, LI=qm_y[1,],
           UI=qm_y[2,]) %>%
  ggplot(aes(x=y, y=preds))+geom_point(col='#073d6d',size=1.5,
                                        alpha=0.8)+
  geom_abline(aes(intercept=0,slope=1), lwd=1.25, col='#55Ad89')+
  geom_ribbon(aes(ymin=LI, ymax=UI), alpha=0.16)+
  ylab(expression(hat(y)))+
  xlab(expression('y'))+
  annotate(geom='rect', xmin=-5.75, xmax=1.25, ymin=3, ymax=4.5,
           fill=alpha('#073d6d', 0.04), color='#073d6d', linewidth=.5)+
  annotate(geom='text', x=0.75, y=3.5,
           label = Lab,
           color='#073d6d',
           hjust=1, vjust=0, lineheight=1,
           size=1.85)+
  theme_minimal(base_size=16)+
  theme(legend.position = "bottom",
        panel.background = element_rect(fill='#f8f9fa',
                                       color=NA),
        plot.background = element_rect(fill='#f8f9fa',
                                       color=NA))
qm = apply(preds1, 1, quantile, probs=c(0.025, 0.975))
qm2 = apply(preds2, 1, quantile, probs=c(0.025, 0.975))
qm3 = apply(preds3, 1, quantile, probs=c(0.025, 0.975))
plot2 = data.frame(x1, preds=preds1_mean, LI=qm[1,],
           UI=qm[2,]) %>%
  ggplot(aes(x=x1, y=preds))+geom_point(col='#073d6d',size=2,
                                        alpha=0.8)+
  geom_line(aes(x=x1, y=sin(4*x1)), lwd=1.5, col='#55Ad89')+
  geom_ribbon(aes(ymin=LI, ymax=UI), alpha=0.16)+
  ylab(expression(beta[1]))+
  xlab(expression('x'[1]))+
  theme_classic(base_size=16)+
  theme(legend.position = "bottom",
        panel.background = element_rect(fill='#f8f9fa',
                                        color=NA),
        plot.background = element_rect(fill='#f8f9fa',
                                       color=NA))


plot3 = data.frame(x2, preds=preds2_mean,
           LI=qm2[1,],
           UI=qm2[2,]) %>%
  ggplot(aes(x=x2, y=preds))+geom_point(col='#073d6d',size=2,
                                        alpha=0.8)+
  geom_line(aes(x=x2, y=coef*abs(x2-0.5)), lwd=1.5, col='#55Ad89')+
  geom_ribbon(aes(ymin=LI, ymax=UI), alpha=0.16)+
  ylab(expression(beta[2]))+
  xlab(expression('x'[2]))+
  theme_classic(base_size=16)+
  theme(legend.position = "bottom",
        panel.background = element_rect(fill='#f8f9fa',
                                        color=NA),
        plot.background = element_rect(fill='#f8f9fa',
                                       color=NA))

plot4 = data.frame(x3, preds=preds3_mean,
           LI=qm3[1,],
           UI=qm3[2,]) %>%
  ggplot(aes(x=x3, y=preds))+geom_point(col='#073d6d',size=2,
                                        alpha=0.8)+
  geom_line(aes(x=x3, y=coef2*x3), lwd=1.5, col='#55Ad89')+
  geom_ribbon(aes(ymin=LI, ymax=UI), alpha=0.16)+
  ylab(expression(beta[3]))+
  xlab(expression('x'[3]))+
  scale_color_wsj()+
  theme_classic(base_size=16)+

  theme(legend.position = "bottom",
        panel.background = element_rect(fill='#f8f9fa',
                                        color=NA),
        plot.background = element_rect(fill='#f8f9fa',
                                       color=NA))




step_plot = grid.arrange(plot1,plot2, plot3,plot4, nrow=2
                         )
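To make the leaf regression idea concrete, here is a minimal sketch of the update inside a single leaf, assuming the standard conjugate normal setup: residuals \(r_i=\mathbf{z}_i^\top\boldsymbol{\beta}+\varepsilon_i\) with \(\varepsilon_i\sim N(0,\sigma^2)\) and a leaf prior \(\boldsymbol{\beta}\sim N(\mathbf{0},\Sigma_0)\). The function `leaf_posterior`, the two-column basis, and the simulated values are hypothetical illustrations, not stochtree's internal code.

```r
# Conjugate posterior for the coefficient vector in one leaf (hypothetical helper):
# beta | r ~ N(m, V), V = (Z'Z / sigma2 + Sigma0^{-1})^{-1}, m = V Z'r / sigma2
leaf_posterior <- function(r, Z, sigma2, Sigma0) {
  Z <- as.matrix(Z)
  V <- solve(crossprod(Z) / sigma2 + solve(Sigma0))
  m <- V %*% crossprod(Z, r) / sigma2
  list(mean = drop(m), cov = V)
}

set.seed(42)
n_leaf <- 200
Z <- cbind(1, runif(n_leaf, 0, 2))       # hypothetical basis: intercept and one z
beta_true <- c(0.5, -1)
r <- drop(Z %*% beta_true) + rnorm(n_leaf, 0, 0.25)

post <- leaf_posterior(r, Z, sigma2 = 0.25^2, Sigma0 = diag(1, 2))
# One posterior draw: mean + L z, where L L' = V (t(chol(V)) is lower triangular)
draw <- post$mean + drop(t(chol(post$cov)) %*% rnorm(2))
post$mean
coef(lm(r ~ Z - 1))                       # least squares, for comparison
```

With a weak prior and plenty of data in the leaf, the posterior mean sits close to least squares; in a sparsely populated leaf, \(\Sigma_0\) shrinks the coefficients toward zero.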


  • Log-linear BART: (Murray 2021) models the log of a positive quantity as a sum of trees. This is particularly useful for modeling the variance term, since the variance must be greater than 0, that is: \(\sigma^2(\mathbf{x})>0\). Placing a log-linear BART prior on \(\sigma^2(\mathbf{x})\) lets the model account for the dependence of \(\sigma^2\) on \(\mathbf{x}\) while guaranteeing positivity. The paper is available on arXiv.

  • Heteroskedastic BART: (Pratola et al. 2020) model heteroskedastic error with a product of regression trees on the variance term, \(\sigma^2(\mathbf{x})=\prod_{j}s_j(\mathbf{x})\), where each tree \(s_j\) has positive leaf values. The reasoning mirrors log-linear BART: a product of positive terms is positive, as any variance must be, and the log of a product of trees is again a sum of trees, \(\log\sigma^2(\mathbf{x})=\sum_j\log s_j(\mathbf{x})\).
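A small numerical check of the positivity argument behind both log-linear and multiplicative variance trees. The two single-split “trees” `g1` and `g2` and their leaf values are made up for illustration; the point is only that summing trees on the log scale and multiplying trees with positive leaves give the same, always positive, variance function.

```r
# Two hypothetical single-split trees on x in [0, 1], leaf values on the log scale
g1 <- function(x) ifelse(x < 0.5, -1.0, 0.4)
g2 <- function(x) ifelse(x < 0.2, 0.3, -0.6)
x <- seq(0, 1, by = 0.1)

# Log-linear form: log sigma^2(x) is a sum of trees
sigma2_loglinear <- exp(g1(x) + g2(x))
# Multiplicative form: product of trees with positive leaves s_j(x) = exp(g_j(x))
sigma2_product <- exp(g1(x)) * exp(g2(x))

all.equal(sigma2_loglinear, sigma2_product)  # TRUE: the two forms coincide
all(sigma2_product > 0)                      # TRUE: positive whatever the leaf values
```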

  • Heteroskedastic Gaussian Processes

    Heteroskedastic Gaussian process regression: (Binois and Gramacy 2021) describe a procedure to model heteroskedasticity with Gaussian processes, if they are your preferred engine for Bayesian functional priors. This interesting blog shows a nice implementation of heteroskedastic GPs, which places the model \(\varepsilon_{i}\mid \mathbf{x}_i\sim N(0,\sigma^2_{i})\) on the errors, where \(\sigma^2_i\) is modeled by a log-linear regression, i.e. \(\ln(\sigma^2_i) = \alpha+\mathbf{x}_i\beta\).
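The log-linear variance model can be fit by maximum likelihood in a few lines. This is a simplified sketch, assuming the mean function is known (zero here) so that only \(\alpha\) and \(\beta\) are estimated; in a heteroskedastic GP the mean would itself be a Gaussian process, and packages such as hetGP handle both jointly. The simulated data and true parameter values are hypothetical.

```r
set.seed(7)
n <- 5000
x <- runif(n, -1, 1)
alpha_true <- -1
beta_true  <- 1.5
# eps_i | x_i ~ N(0, sigma_i^2) with ln(sigma_i^2) = alpha + x_i * beta
eps <- rnorm(n, 0, sqrt(exp(alpha_true + beta_true * x)))

# Negative log-likelihood of the log-linear variance model (mean assumed known, zero)
nll <- function(par) {
  -sum(dnorm(eps, mean = 0, sd = sqrt(exp(par[1] + par[2] * x)), log = TRUE))
}
fit <- optim(c(0, 0), nll)   # Nelder-Mead from alpha = beta = 0
fit$par                      # estimates of (alpha, beta)
```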

  • Dirichlet BART: Instead of a uniform prior over which variable to split on, this uses a sparsity-inducing Dirichlet prior on the splitting probabilities. This is more similar to the stochastic search variable selection method in Chapter 9 than to the “fit the fit” approach.

BART Dirichlet variable selection (Linero and Yang 2018). Section 7.1 of (H. Chipman et al. 2014) provides a nice overview of this approach.
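A minimal sketch of how a Dirichlet prior changes the split-variable probabilities, assuming concentration parameters \(\alpha/p\) for each of the \(p\) variables (the value \(\alpha=1\) is a hypothetical choice). A Dirichlet draw is generated by normalizing independent gamma draws; with small concentrations most of the probability lands on a handful of variables, whereas standard BART uses the uniform vector \(1/p\).

```r
set.seed(3)
p <- 50                       # number of candidate split variables
alpha <- 1                    # hypothetical sparsity parameter

# One draw from Dirichlet(conc) by normalizing independent gamma draws
rdirichlet_one <- function(conc) {
  g <- rgamma(length(conc), shape = conc)
  g / sum(g)
}
s_sparse  <- rdirichlet_one(rep(alpha / p, p))  # sparse split probabilities
s_uniform <- rep(1 / p, p)                       # standard BART: all variables equally likely

round(sum(sort(s_sparse, decreasing = TRUE)[1:5]), 3)  # share held by the top 5 variables
sum(s_uniform[1:5])                                    # 0.1 under the uniform prior
split_vars <- sample(p, 20, replace = TRUE, prob = s_sparse)  # variables used by 20 splits
table(split_vars)
```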

  • XBART: This is an algorithm that provides a fast approximation to BART posterior exploration. Rather than changing each tree by one small step at each posterior draw, the authors grow a new tree from scratch at each draw. The tree is grown recursively, with splitting rules defined according to the marginal likelihood of splitting (growing or stopping a tree) implied by BART models, in a similar manner to the BART construction. Which variable to split on and where to cut are found from an application of Bayes’ rule (the joint probability of a variable being chosen to split on and of where the cut is placed). One pass through the forest, in which each tree is regrown from scratch against the residual of the other trees before the next tree is built, is called a sweep. A key point is that the tree is grown stochastically: the decision to keep splitting or to stop is sampled from a probability distribution rather than chosen greedily. We have seen how to sample from a vector of probability weights in R using the sample function.

    Image from Andrew Herren’s slide deck on GitHub.

    Leaf parameters are sampled in each sweep, and the error variance is updated once per sweep after the trees are grown, similar to BART. Additionally, several clever tricks are used to speed up the process, such as pre-sorting features. See (He and Hahn 2023), the XBART paper, and the XBART writeup. XBART draws are not draws from the BART posterior, but they are a reasonable approximation. If instead the XBART trees are used as initial values for the BART MCMC (a data-informed starting point, as in the warm start with gfr = T in the code above), then we can mix the two approaches together in a graceful way: proper uncertainty quantification plus a speedup, because fewer MCMC samples are required (and potential improvements in the face of many unnecessary variables).

    Crucially, XBART is “quasi-Bayesian”: it still uses priors and evaluates marginal likelihoods as the criterion for splitting or stopping, but samples from the XBART procedure do not yield a valid BART posterior. That being said, XBART does sample parameters between “sweeps” through the forest (where each tree is rebuilt from scratch). XBART certainly owes part of its competitive performance to retaining BART’s regularization mechanisms, but XBART is also a testament to stochastic sampling. The Bayesian tools are a really nice, convenient, and logical way to sample trees and associated parameters, but they are not the only way!
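The stochastic split step at a single node can be sketched as follows. This is a simplified illustration under stated assumptions: one covariate, a constant-mean leaf with prior \(\mu\sim N(0,\tau)\), known \(\sigma^2\), and weights formed by Bayes’ rule from the BART depth prior times the leaf marginal likelihoods. The function names are hypothetical, and the XBART paper’s exact weights, its joint treatment of all variables, and its speedups (such as pre-sorting features) are not reproduced.

```r
# Log marginal likelihood of residuals r in one leaf, integrating out the
# leaf mean mu ~ N(0, tau), with r_i = mu + e_i and e_i ~ N(0, sigma2)
leaf_log_ml <- function(r, tau, sigma2) {
  n <- length(r)
  -0.5 * n * log(2 * pi * sigma2) - 0.5 * log(1 + n * tau / sigma2) -
    (sum(r^2) - tau * sum(r)^2 / (sigma2 + n * tau)) / (2 * sigma2)
}

# Sample one split decision for covariate x at a node of the given depth.
# Returns the chosen cutpoint (rule x <= cutpoint), or NA for "stop splitting".
sample_split <- function(x, r, tau, sigma2, depth, alpha = 0.95, beta = 2) {
  cuts <- sort(unique(x))
  cuts <- cuts[-length(cuts)]                  # keep both children non-empty
  p_split <- alpha * (1 + depth)^(-beta)        # BART prior probability of splitting
  log_w_split <- sapply(cuts, function(cp)
    log(p_split / length(cuts)) +
      leaf_log_ml(r[x <= cp], tau, sigma2) + leaf_log_ml(r[x > cp], tau, sigma2))
  log_w_stop <- log(1 - p_split) + leaf_log_ml(r, tau, sigma2)
  log_w <- c(log_w_split, log_w_stop)
  w <- exp(log_w - max(log_w))                  # stabilize before exponentiating
  choice <- sample(length(w), 1, prob = w)      # stochastic, not greedy
  if (choice > length(cuts)) NA else cuts[choice]
}

set.seed(10)
x <- runif(200)
r <- ifelse(x > 0.5, 1, -1) + rnorm(200, 0, 0.3)
chosen_cut <- sample_split(x, r, tau = 0.5, sigma2 = 0.09, depth = 0)
chosen_cut
```

Because the cutpoint is drawn with `sample` rather than picked by maximizing, repeated calls can return different trees; with a weak signal, “stop” and several cutpoints would all carry real probability.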

  • Targeted smoothing BART: Targeted smoothing BART (Starling et al. 2020) introduces a Gaussian process in the leaf nodes for a single covariate. The idea is to construct the trees using all but the one targeted variable, then fit a Gaussian process to the data in each leaf with that variable as the input. Very cool paper. One major benefit of this approach is that the trees are not split based on the Gaussian process likelihood; rather, the Gaussian process is fit after the tree structure is determined. This still requires some extra computation, as fitting a GP in a leaf node costs \(\mathcal{O}(N_{\text{obs leaf}}^3)\).
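To see where the \(\mathcal{O}(N_{\text{obs leaf}}^3)\) cost comes from, here is a minimal sketch of Gaussian process smoothing within one leaf over the targeted variable \(t\), assuming a squared exponential kernel with fixed, hypothetical hyperparameters. It illustrates per-leaf GP smoothing only and is not the tsBART sampler; the Cholesky factorization of the \(N_{\text{obs leaf}}\times N_{\text{obs leaf}}\) kernel matrix is the cubic step.

```r
# GP posterior mean in one leaf, smoothing over the targeted covariate t only
gp_leaf_smooth <- function(t_train, y_train, t_new, ell = 0.2, tau2 = 1, sigma2 = 0.01) {
  kern <- function(a, b) tau2 * exp(-outer(a, b, "-")^2 / (2 * ell^2))  # squared exponential
  K <- kern(t_train, t_train) + diag(sigma2, length(t_train))
  R <- chol(K)                                     # O(n_leaf^3): the dominant cost
  a <- backsolve(R, forwardsolve(t(R), y_train))   # solves K a = y via R'R = K
  drop(kern(t_new, t_train) %*% a)
}

set.seed(11)
t_leaf <- sort(runif(30))          # targeted variable for the points in one leaf
y_leaf <- sin(2 * pi * t_leaf) + rnorm(30, 0, 0.1)
pred <- gp_leaf_smooth(t_leaf, y_leaf, t_new = c(0.25, 0.75))
pred                               # close to sin(pi/2) = 1 and sin(3*pi/2) = -1
```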

  • GP BART: This paper, in contrast to targeted smoothing BART, fits a Gaussian process using every variable instead of just one. In addition, the trees are built based on the GP likelihood, so every proposed change to a tree requires evaluating a Gaussian process likelihood, on top of the Gaussian process regression in each leaf node. This brings some performance improvements, but at the cost of significantly more compute time: each GP evaluation costs \(\mathcal{O}(N_{\text{obs leaf}}^3)\) for the matrix factorization plus \(\mathcal{O}(N_{\text{obs leaf}}^2\,p)\) to build a kernel over all \(p\) covariates, and it is repeated for every tree proposal rather than once per leaf after the tree is fixed. This paper also includes a rotation option for BART, which puts a prior on whether or not to apply a rotation matrix to the features. GP BART, arXiv link (Maia, Murphy, and Parnell 2024).

  • Treed-GP

Treed GP (Gramacy and Lee 2008). The idea here is to combine trees and Gaussian processes. A tree partitions the covariate space, and a separate Gaussian process is fit within each partition defined by the tree. The tree is built based on the Gaussian process likelihood. This is a “treed GP”.

  • Local Gaussian Process extrapolation for BART models –> This is an awesome idea:

GP BART extrapolation, arXiv link: (Wang, He, and Hahn 2024) graft a Gaussian process onto the leaf nodes of a BART model. By grafting a GP onto BART we mean the following. After training a BART model and sorting new test data into the correct leaf bins using the trained BART, we check how “similar” the test points are to the training points that occupy the same leaf. If a test point has an input value outside the range of the training observations in its leaf, it is considered an “exterior point”. If it is not an exterior point, it retains its original BART prediction. If it is an exterior point, its prediction comes from a Gaussian process regression conditioned on the training points in that leaf and evaluated at all exterior test points.

Excalidraw visual illustration of the GP Bart extrapolation. This is repeated for every leaf node of every tree, and done for every MCMC draw as well.


Notice, the Gaussian process prediction happens after the BART model is built, reducing much of the computational burden4. It is simply used to replace the traditional predictions for points deemed to be in the extrapolation zone. In this zone, the Gaussian process is more flexible5. This is a GP’ed tree, as opposed to the “treed GP”.

Illustration of the extrapolation idea.
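The exterior-point logic for a single leaf can be sketched as below, assuming one input dimension, a squared exponential kernel with hypothetical fixed hyperparameters, and (our assumption, not necessarily the paper’s choice) a GP centered at the leaf mean. `leaf_predict_with_gp` is a hypothetical name; in the full method this step would be repeated for every leaf of every tree and for every MCMC draw.

```r
sq_exp <- function(a, b, ell, tau2) tau2 * exp(-outer(a, b, "-")^2 / (2 * ell^2))

# Interior test points keep the leaf mean; exterior test points (outside the
# leaf's training range) get a GP prediction from that leaf's training data
leaf_predict_with_gp <- function(x_train, y_train, x_test, leaf_mean,
                                 ell = 0.5, tau2 = 1, sigma2 = 0.01) {
  exterior <- x_test < min(x_train) | x_test > max(x_train)
  pred <- rep(leaf_mean, length(x_test))
  if (any(exterior)) {
    K <- sq_exp(x_train, x_train, ell, tau2) + diag(sigma2, length(x_train))
    K_star <- sq_exp(x_test[exterior], x_train, ell, tau2)
    pred[exterior] <- leaf_mean + drop(K_star %*% solve(K, y_train - leaf_mean))
  }
  list(pred = pred, exterior = exterior)
}

x_leaf <- seq(0, 1, length.out = 20)   # training inputs that landed in this leaf
y_leaf <- 2 * x_leaf                   # a trend the constant leaf mean cannot follow
mu_leaf <- mean(y_leaf)
out <- leaf_predict_with_gp(x_leaf, y_leaf, x_test = c(0.5, 1.2), leaf_mean = mu_leaf)
out$exterior   # FALSE TRUE
out$pred       # leaf mean for 0.5; a GP continuation of the trend for 1.2
```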
  • Soft-BART: The idea here is that instead of using “step functions”, which represent the average in each leaf node, a sigmoid is fit within the leaves. The width of the sigmoid makes the boundary “soft”: points near a boundary receive substantial weight from the leaf on the other side (how much depends on the parameters of the sigmoid), and the farther a point is from the boundary, the more its prediction is dominated by its own leaf’s mean. Mathematically, replace the indicator formulation of a tree, \(\sum_{j}^{\text{$\#$ partitions}}\mu_{j}\cdot \mathbf{1}(\mathbf{x}\in \mathcal{S}_j)\), with sigmoid weights. The prediction of a tree is now \(\sum_{j}^{\text{$\#$ partitions}}\mu_{j}\cdot \Psi_j(\mathbf{x})\), where \(\Psi_j(\mathbf{x})\) is positive for every \(\mathbf{x}\), not just for \(\mathbf{x}\in\mathcal{S}_j\), see Figure 13.4. Since we are no longer just adding together step functions, the output is not just the mean of one leaf: it is the sum of the leaf means, each multiplied by its own sigmoid weight, with one such weight for every partition of the tree. Notice, this is very similar to a neural network with a sigmoid activation function. So SoftBART is more akin to a Bayesian neural network, but with the number of hidden neurons (where the sigmoid “activations” happen) not set in stone, but learned adaptively. It seems to work reasonably well, albeit at large computational cost. In the “lower” noise setting, this tends to work really well. However, in higher noise settings, the performance seems to degrade. Additionally, BART, particularly with a larger number of trees, produces practically smooth estimates: the additional smoothing granted from averaging predictions across the posterior draws minimizes the smoothness “issue” inherent to base BART.
Figure 13.4: The bottom right panel shows “step functions”, whereas the other three show different sigmoids; the less a curve looks like a decision tree step, the more the boundary is distorted. From [@linerobayesian2018], reproduced in [@hill2020bayesian].

paper link, arxiv version (Linero 2022).
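A one-split example makes the difference between hard and soft routing concrete. This is a minimal sketch assuming a logistic sigmoid (`plogis`) with a hypothetical bandwidth; SoftBART’s actual parameterization and its prior on the bandwidth are in the paper. As the bandwidth shrinks, the soft tree collapses back to the ordinary step function away from the cutpoint.

```r
mu_left <- -1; mu_right <- 2; cutpoint <- 0.5   # one hypothetical split and its leaf means

hard_tree <- function(x) mu_left * (x <= cutpoint) + mu_right * (x > cutpoint)
soft_tree <- function(x, bandwidth) {
  w_right <- plogis((x - cutpoint) / bandwidth)  # sigmoid weight on the right leaf
  mu_left * (1 - w_right) + mu_right * w_right    # weighted sum of both leaf means
}

x <- c(0.1, 0.45, 0.5, 0.55, 0.9)
hard_tree(x)                        # -1 -1 -1 2 2
soft_tree(x, bandwidth = 0.05)      # blends the two means near the cutpoint
x_away <- c(0.1, 0.45, 0.55, 0.9)
soft_tree(x_away, bandwidth = 1e-4) # essentially the hard tree
```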

  • Density regression BART, a fully nonparametric BART (not merely heteroskedastic): (Orlandi et al. 2021) develop a full density regression, with BART priors for the mean and (heteroskedastic) standard deviation components in a mixture of normals representation (with covariate dependence) of \(\Pr(y\mid \mathbf{x})\). This matters because heteroskedastic BART models still usually assume normal curves, just with varying width at different \(\mathbf{x}\). Similar to the idea of summing multiple normal distributions (generated from latent mean and variance parameters) to model a more flexible distribution, we can include a latent variable that explains the skew, kurtosis, or shape (e.g. multi-modality) of the distribution. Suppose \(y=f(U)+\varepsilon\), with \(U\sim \text{uniform}(0,1)\) and \(\varepsilon\mid U\sim N(0,\sigma(U)^2)\). Integrating out (averaging over) \(U\) yields the marginal distribution of \(y\): \(\Pr(y)=\int_{0}^{1}\frac{1}{\sigma(u)}\phi\left(\frac{y-f(u)}{\sigma(u)}\right)\text{d}u\). We can include covariates by putting BART priors on \(\sigma(u,\mathbf{x})\) and \(f(u, \mathbf{x})\), which means the mean and variance trees split on this latent variable as well as on \(\mathbf{x}\), and then calculate \(\Pr(y\mid \mathbf{x})\) by integrating over the possible values of the latent variable.
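The latent-variable integral above can be evaluated numerically. In this minimal sketch, `f_u` and `s_u` are hypothetical stand-ins for \(f(u)\) and \(\sigma(u)\) (in the BART version they would be forests that also depend on \(\mathbf{x}\)), and the integral over \(U\) is approximated by a midpoint rule. A single step in \(f(u)\) is enough to turn normal errors into a bimodal distribution for \(y\).

```r
f_u <- function(u) ifelse(u < 0.5, -2, 2)   # hypothetical latent mean function f(u)
s_u <- function(u) rep(0.5, length(u))      # hypothetical sigma(u), constant here
u_grid <- (seq_len(1000) - 0.5) / 1000      # midpoint rule on (0, 1)

# Pr(y) = integral over u of (1/sigma(u)) phi((y - f(u)) / sigma(u)), as an average over u
dens_y <- function(y) sapply(y, function(yy) mean(dnorm(yy, f_u(u_grid), s_u(u_grid))))

dens_y(c(-2, 0, 2))                 # high at -2 and 2, low at 0: bimodal
integrate(dens_y, -10, 10)$value    # close to 1
```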

  • Adaptive Gaussian processes: This paper combines ideas from neural networks and Gaussian processes. Deep Gaussian processes (Sauer, Gramacy, and Higdon 2023) are designed to create non-stationary covariance kernels in a data-adaptive way. The hidden layers, which are themselves Gaussian processes, “warp” the input space that is passed into the final Gaussian process regression, ideally allowing different regions of the input space to essentially have separate Gaussian process regressions, building in non-stationarity. This is very similar to what ts-BART does. It is also similar to Treed GP (Gramacy and Lee 2008), but that machinery is very expensive. Some slides.

Figure from page 4 of the arXiv paper.
  • Prior on number of trees BART: See this link. They place a truncated Poisson prior on the number of trees. This may not be strictly necessary, but it is nice to see a fully Bayesian BART that does not require setting the number of trees ad hoc or via cross validation.

  • Random intercept BART: (Tan and Roy 2019).

  • BART for survival analysis: Survival analysis deals with estimating “time to event” data, which is complicated by censoring. For some people, the event happens before they come under observation, so we only know that their event time is below some value (for example, an infection that occurred before the first test). For others, the study ends before the event occurs (such as someone not terminating as an employee during a study timeframe). These are known as “left” and “right” censoring, respectively. In this world, the outcome is the time at which an event occurs for someone. Whether the event was observed is recorded for each individual by a binary indicator.

    Censoring is problematic because it complicates simple estimation of the expected time to event. We cannot predict the time to event for those who did not experience the event (these individuals would be “right” censored) from a model learnt only on those who did experience the event, because this estimate would be skewed towards the (shorter) times of those who did: a censored unit is known to have survived at least until its censoring time. We can estimate the time to event if we include censoring information in our imputation model. (Sparapani et al. 2023), in their NFT BART model, use a data augmentation scheme to fill in the censored times: they draw the censored values randomly from a normal distribution with a BART mean and variance, truncated to lie above each unit’s censoring time, and the BART fit is then updated with the completed data.
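The data augmentation step for right censored units amounts to drawing from a truncated normal on the log-time scale. Here is a minimal inverse-CDF sketch, assuming the BART mean and standard deviation for each censored unit are already available (the values below are hypothetical). For censoring times far in the upper tail, `pnorm` approaches 1 and a log-scale version of the inverse CDF would be needed.

```r
# Draw from N(mu, sigma^2) truncated to (log_c, Inf) by inverting the CDF
draw_censored_log_time <- function(log_c, mu, sigma) {
  p_low <- pnorm(log_c, mean = mu, sd = sigma)    # mass below the censoring time
  u <- runif(length(log_c), min = p_low, max = 1) # uniform on the allowed CDF range
  qnorm(u, mean = mu, sd = sigma)
}

set.seed(5)
log_c <- log(c(2, 5, 8))        # hypothetical censoring times, log scale
mu    <- c(1.0, 1.2, 1.5)       # hypothetical BART means of log time
sigma <- c(0.5, 0.6, 0.4)       # hypothetical BART standard deviations
imputed <- exp(draw_censored_log_time(log_c, mu, sigma))
imputed
all(imputed >= exp(log_c))      # TRUE: imputed times exceed the censoring times
```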

    In some sense, the expected time to event is a perfectly fine estimand. As we will see, however, the survival curve, which is one minus the cumulative distribution function of the time to event, describes the probability that the event has not occurred by time \(t\). This is an entire curve and provides a much richer estimand. And, it is estimable in flexible ways as we are about to describe! So we may as well go for it over the single number summary.

    To set the stage, let us introduce some terminology. People (or parts or whatever your observations are) who do experience the event (in the case of right censoring) experience it at time \(T_{\text{event}}\), and for those who have yet to experience it we merely observe the last time, \(T_{\text{last}}\). A main assumption going forward is that these two times are independent, i.e. censoring is non-informative: the fact that a unit is censored does not provide any new information about when its event would occur.

    For example, let’s assume we study workforce attrition between 2013 and 2025 at a company. The length of an employee’s career if they terminate is \(T_{\text{event}}\), and the length if they are still working is \(T_{\text{last}}\). In 2019, the company institutes a 60 hour work week minimum. This causes a mass exodus of workers with children, who experience the event and terminate. \(T_{\text{event}}\) and \(T_{\text{last}}\) are no longer independent, as knowing whether someone has children gives information about their censoring time. However, conditioning on whether employees have children and on the year of termination gets around this issue.

    Instead, the estimand of interest becomes the survival curve, which is the probability the time to event occurs after some time \(t\). At a given time, we can calculate the probability of experiencing an event, since both censored and uncensored observations are included in this calculation6. That is, while we do not observe the censored individuals’ time to event outcomes, at every time point up to their censoring time we do observe whether or not they have experienced the event. However, the survival curve requires estimation of a distribution of the time to event, which is a more difficult task than estimating a mean time to event. While more difficult, estimates of a distribution endow us with much more information! And we can always calculate the mean after estimating the distribution.

    The history of survival analysis can be (poorly) summarized as follows. The Kaplan-Meier estimator (Kaplan and Meier 1958) revolutionized (and initiated?) the field of survival analysis. The Kaplan-Meier estimator is based on a relatively simple idea. Imagine modeling whether or not a person survives at a time \(t\) as a coin flip, with heads meaning survival and tails meaning the event occurs. Now imagine a sequence of (independent) coin flips in a row, with each coin flip representing survival through the next time point. The cumulative probability of survival becomes the product of the probabilities for each individual coin flip. Surviving ten time points is akin to flipping ten consecutive heads with 10 potentially differently weighted coins. To estimate the probability for each coin flip, (Kaplan and Meier 1958) assigned each “flip” a probability of tails equal to the proportion of individuals still at risk (neither failed nor censored yet) who experienced the event at the corresponding time. For the first flip, this was the proportion of people who experienced the event in the first time step7. For the 8th flip, it was the proportion of people still at risk at the 8th time step who experienced the event then. We can then plot the survival curves at every time point, with different curves plotted for different subgroups, and compare them. Logically, the idea is very sound, but also limited! For example, dealing with covariates involves pair-wise comparisons of different curves, which can become cumbersome quickly.
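The Kaplan-Meier product of coin flips takes only a few lines of base R. This is a minimal sketch on a tiny hypothetical data set (observed times with event indicators, where 0 means censored); in practice one would use a dedicated package such as survival.

```r
# Kaplan-Meier: S(t) = product over event times t_i <= t of (1 - d_i / n_i),
# with d_i events at t_i and n_i units still at risk just before t_i
kaplan_meier <- function(time, event) {
  event_times <- sort(unique(time[event == 1]))
  n_risk  <- sapply(event_times, function(t) sum(time >= t))
  n_event <- sapply(event_times, function(t) sum(time == t & event == 1))
  data.frame(time = event_times, n_risk = n_risk, n_event = n_event,
             surv = cumprod(1 - n_event / n_risk))
}

km <- kaplan_meier(time = c(2, 3, 3, 5, 6, 8), event = c(1, 1, 0, 1, 0, 1))
km   # surv = 5/6, 2/3, 4/9, 0
```

The censored unit at time 3 still counts as at risk at time 3, and the one censored at time 6 drops out of the risk set before time 8, which is how censored observations inform the curve without an observed event time.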

    The inclusion of covariates provides additional value for multiple reasons. 1) The assumption of independence between event and censoring times is far more plausible after accounting8 for additional information. This assumption is not really “testable” statistically, but it is easier to live with knowing we account for extraneous information. 2) Detailed subgroup analyses become available after running the survival analysis with further covariate breakdowns. We could even make individualized (up to the amount of data recorded on a person) survival curves!

    To incorporate additional information more rigorously, the Cox proportional hazards model (Cox 1972) became extraordinarily popular. The Cox model allows for a more comprehensive handling of covariates, but is still quite limited.

    The Cox model is based on the hazard function having the form \(h(t\mid \mathbf{x})=h_0(t)e^{-\sum_{i=1}^{p}\beta_{i}x_i}\). For every unique set of realized covariate values, \(\mathbf{x}_k\), a curve/trajectory as a function of time is produced from this function. There are two troubling assumptions implied. The first is the “proportional hazards” assumption, which states that the ratio of any individual’s hazard to the baseline hazard is constant in time: every individual curve is some multiple of the baseline trajectory \(h_0(t)\), and that multiple is the same at every time point. This also implies every individual in a study has a curve that is a multiple of everyone else’s, so everyone experiences their maximum hazard at the same time. Additionally, Cox proportional hazards models are still linear with respect to covariates (up to pre-specified transformations/interactions). Those are some pretty strong limitations that should (but probably do not in practice) hinder the usefulness of the model.

    An alternative that drops the restrictive proportional hazards assumption is the family of “accelerated failure time” (AFT) models (Kalbfleisch and Prentice 1980). The hazard function is now modeled as \(h(t\mid \mathbf{x})=h_0\left(t\cdot e^{-\sum_{i=1}^{p}\beta_{i}x_i}\right)e^{-\sum_{i=1}^{p}\beta_{i}x_i}\). Rescaling time inside the baseline hazard by a function of the covariates means we are no longer beholden to the proportional hazards assumption…nice. But this comes at a cost: the baseline hazard now has to be specified. So while curves can now differ from each other in both scale and shape, the shape has to belong to a pre-specified class of possible functions. So-called “frailty” models are similar to AFT, but use a random effect multiplier for the baseline function. Sensitivity to model choices is still a big problem in this modeling paradigm.
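The difference between the two hazard models can be checked numerically. This minimal sketch assumes a hypothetical log-logistic baseline hazard \(h_0(t)=2t/(1+t^2)\), which peaks at \(t=1\), and a single individual with linear predictor \(\sum_i\beta_ix_i=\log 2\). (With a Weibull baseline the two models coincide, which is why a non-Weibull baseline is used here.)

```r
# Hypothetical log-logistic baseline hazard (shape 2); it peaks at t = 1
h0 <- function(t) 2 * t / (1 + t^2)
eta <- log(2)                                   # sum_i beta_i x_i for one individual
t_grid <- seq(0.01, 5, by = 0.01)

h_ph  <- h0(t_grid) * exp(-eta)                 # Cox: h0(t) e^{-eta}
h_aft <- h0(t_grid * exp(-eta)) * exp(-eta)     # AFT: h0(t e^{-eta}) e^{-eta}

ratio_ph  <- h_ph / h0(t_grid)    # constant in t (proportional hazards)
ratio_aft <- h_aft / h0(t_grid)   # changes with t
range(ratio_ph)
range(ratio_aft)
c(peak_baseline = t_grid[which.max(h0(t_grid))],
  peak_ph       = t_grid[which.max(h_ph)],
  peak_aft      = t_grid[which.max(h_aft)])
```

Under proportional hazards the peak stays at \(t=1\) for everyone; under the AFT form it moves to \(t=e^{\eta}=2\) for this individual.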

    Chapter 5 of (N. Krantsevich 2023) and Richard Hahn’s notes present a derivation of the hazard function9, a key estimand in the survival analysis literature, for right censored data. The hazard function describes the probability that the event happens at time \(t\), given that it has not yet happened. This function is important because knowledge of the hazard function lets us recover survival curves relatively easily. Crucially, the hazard estimand is presented in its “reduced form”, which allows it to be estimated from observed data. In this form, the hazard function is expressed in terms of the probability density function of the observed time to event and the probability that units are censored or not. Importantly, this facilitates the use of flexible machine learning tools, like say BART.
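    The reduced-form idea can be checked numerically. The sketch below assumes independent censoring and, for illustration only, log-normal event and censoring times with made-up parameters. With observed time \(T_\text{obs}=\min(T,C)\) and event indicator \(\delta\), one way to write such a reduced form is \(h(t)=P(\delta=1\mid T_\text{obs}=t)\,f_\text{obs}(t)/S_\text{obs}(t)\), which involves only observable quantities; the survival curve then follows from \(S(t)=\exp\left(-\int_0^t h(s)\,\text{d}s\right)\).

```r
# Reduced-form hazard sketch under independent right censoring, with
# log-normal event and censoring times chosen purely for illustration.
t_grid = seq(0.01, 8, length.out = 2000)
mu_T = 0.5; sd_T = 0.6   # event time T: log-scale mean and sd (assumed)
mu_C = 1.0; sd_C = 0.8   # censoring time C: log-scale mean and sd (assumed)

f_T = dlnorm(t_grid, mu_T, sd_T)
S_T = plnorm(t_grid, mu_T, sd_T, lower.tail = FALSE)
f_C = dlnorm(t_grid, mu_C, sd_C)
S_C = plnorm(t_grid, mu_C, sd_C, lower.tail = FALSE)

# Quantities defined for the observed data T_obs = min(T, C), delta = 1{T <= C}
f_obs = f_T * S_C + f_C * S_T        # density of the observed time
S_obs = S_T * S_C                    # P(T_obs > t)
p_event = f_T * S_C / f_obs          # P(delta = 1 | T_obs = t)

h_reduced = p_event * f_obs / S_obs  # reduced-form hazard
h_true = f_T / S_T                   # hazard of the event time itself

# Recover the survival curve, S(t) = exp(-integral of h), via the trapezoid rule
H = c(0, cumsum(diff(t_grid) * (head(h_reduced, -1) + tail(h_reduced, -1)) / 2))
S_recovered = S_T[1] * exp(-H)       # anchor at the first grid point
```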

    The main difficulty in (N. Krantsevich 2023) is that the estimator requires density regression on the observed time to event. Their solution is to model these densities as log-normal distributions, with heteroskedastic BART priors on the mean and variance terms. That is: \(\text{density of observed time to event}\sim\text{lognormal}\left(t, \mu(\mathbf{x}), \sigma(\mathbf{x})\right)\), where \(\mu(\mathbf{x})\) and \(\sigma(\mathbf{x})\) are estimated by the BART model \(\log (t) = \mu(\mathbf{x})+\varepsilon;\quad \varepsilon\sim N\left(0,\sigma(\mathbf{x})\right)\). BART models the log of the observed time to event (either when the event occurs or when the study ends) with the log-linear forests of (Murray 2021), which ensures the mean and variance terms are on the same scale when drawn from the log-normal distribution10. The log-normal distribution entails:

    1) The support is positive, meaning events have to happen after time 0.

    2) The distribution has a definite (right) skewness and kurtosis in addition to its mean and variance, both determined by the variance term.

    3) The variance is learned for every unit, so we can do more granular studies of survival/hazard curves.
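    To see which shapes the log-normal assumption permits (footnote 11 suggests doing this before looking at data), here is a short sketch with arbitrary \((\mu,\sigma)\) pairs. It also computes the closed-form mean, variance, skewness, and excess kurtosis; skewness and kurtosis depend on \(\sigma\) alone, while \(\mu\) only rescales time.

```r
# Shapes allowed by log(t) = mu + eps, eps ~ N(0, sigma^2). sigma is a standard
# deviation here, which is what dlnorm()'s sdlog argument expects.
# The (mu, sigma) pairs are arbitrary choices for illustration.
lognormal_moments = function(mu, sigma) {
  s2 = sigma^2
  c(mean = exp(mu + s2 / 2),
    var = (exp(s2) - 1) * exp(2 * mu + s2),
    skew = (exp(s2) + 2) * sqrt(exp(s2) - 1),                  # always positive
    excess_kurt = exp(4 * s2) + 2 * exp(3 * s2) + 3 * exp(2 * s2) - 6)
}

params = list(c(0, 0.25), c(0, 0.5), c(0, 1), c(1, 0.5))
t_grid = seq(0.001, 6, length.out = 500)
dens = sapply(params, function(p) dlnorm(t_grid, meanlog = p[1], sdlog = p[2]))
matplot(t_grid, dens, type = "l", lty = 1, lwd = 2, col = 1:4,
        xlab = "t", ylab = "density")
legend("topright", lty = 1, lwd = 2, col = 1:4,
       legend = sapply(params, function(p) sprintf("mu = %g, sigma = %g", p[1], p[2])))

# One column per (mu, sigma) pair; rows are mean, var, skew, excess_kurt
moments = sapply(params, function(p) lognormal_moments(p[1], p[2]))

# Numerical check of the mean formula for mu = 0, sigma = 0.5
num_mean = integrate(function(t) t * dlnorm(t, 0, 0.5), 0, Inf)$value
```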

    Ideally, as we will discuss, we’d do a full covariate dependent density regression, but this ain’t a half bad start.

    Estimating the probability of censoring is also necessary in this procedure. This step can be bypassed by imputing the missing time to event outcomes, although estimating the probability of censoring achieves a similar aim and boils down to a binary classification problem.
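    A sketch of the censoring-probability step as plain binary classification. The simulated data and the single covariate are assumptions, and logistic regression via glm() stands in for a BART classifier (such as the probit BART in the Bonus code); the target is \(P(\delta=1\mid t_\text{obs},\mathbf{x})\).

```r
# Estimating P(event observed | time, covariates) as binary classification.
# Simulated data, a single covariate x, and logistic regression (standing in
# for a BART-based classifier) are all assumptions of this sketch.
set.seed(1)
n = 2000
x = rnorm(n)
event_time = rlnorm(n, meanlog = 0.5 + 0.5 * x, sdlog = 0.6)
censor_time = rlnorm(n, meanlog = 1, sdlog = 0.8)

t_obs = pmin(event_time, censor_time)            # what we actually observe
delta = as.integer(event_time <= censor_time)    # 1 = event seen, 0 = censored

fit = glm(delta ~ log(t_obs) + x, family = binomial())
p_event = predict(fit, type = "response")        # estimated P(delta = 1 | t_obs, x)
summary(p_event)
```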

    Assuming the observed time can be modeled by a log-normal distribution (which means the log of the observed time should follow a normal distribution) is a big assumption that may not be true11, but otherwise this is a very cool approach.

    A cool dataset to study would be the (fictional) IBM workforce retention dataset, which records whether employees leave the company. Employees who have not left are right censored, since we do not observe how long they will ultimately stay at IBM. (Sparapani et al. 2023) also have an interesting BART survival analysis paper. They use BART to modify the accelerated failure time model, modeling the log of the time to event as the outcome, with a BART prior for the mean, a heteroskedastic BART prior for the variance (Pratola et al. 2020), and a Dirichlet process for the error distribution. The method works great, but is a little more computationally heavy. It is also based on the accelerated failure time model, so they must still estimate the baseline trajectory function, even if it is done with a very flexible BART model.

  • Causal inference! We have already discussed (Hill 2011) (Hill’s BART paper), (P. Richard Hahn, Murray, and Carvalho 2020) (BCF), and accelerated BCF, or XBCF (N. Krantsevich, He, and Hahn 2023). We also mentioned (Alcantara et al. 2024) and (Wang, He, and Hahn 2024) as interesting BART-based causal inference methodology. (McCulloch et al. 2021) developed a BART-based instrumental variable estimator, reliant on Dirichlet process mixture errors.

(Thal and Finucane 2023), on BART’s performance in the Atlantic Causal Inference Conference competition, showcase the strong performance of BART-based methods.

  • SHAMELESS PLUG: Do forecasts of bankruptcy cause bankruptcy? (Papakostas et al. 2023). This causal question asks whether auditors signaling that a company is at risk causes it to go bankrupt. Of course, companies that receive a public signal of risk are riskier to begin with. But investors who panic on hearing the auditors’ caution may also pull back investment, expediting the bankruptcy. Further complicating the question, auditors do not disclose all the data they use to make their decisions. Even if they did, they certainly have interactions with companies that are impossible to quantify (vibes, culture, etc.). So we cannot simply control for all potential confounding when we know we cannot measure all of it. Instead, this paper performs a sensitivity analysis with different distributions for the hidden confounding.

    Using this R package (SFPSA), try to find data on the following scenario and implement the methods of (Papakostas et al. 2023). The scenario is whether being an underdog according to betting lines causes teams to outperform. This is based on an idea Demetri had after seeing this image back in 2018 (?).

The Red Sox seemed to really overperform their betting lines.

11 Before seeing any data, plot log-normal distributions with different means and variances to see the set of shapes permitted by this assumption, and then decide if the assumption is appropriate. You should also plot histograms of time at IBM, including for subgroups of interest. Additionally, check whether the predicted log-normal distribution from BART (e.g., dlnorm(time_vector, BART_mean, BART_sd); note that the third argument of dlnorm is the standard deviation on the log scale, not the variance) fits the observed distribution well.

10 The hazard function can be interpreted as the probability that a person leaves the study (experiences the event) in the next instant after time \(t\), given that they are still in the study up to that point.

9 The hazard function can be interpreted as the probability that a person leaves the study (experiences the event) in the next instant after time \(t\), given that they are still in the study up to that point. Recall that a survival curve gives the probability that an individual stays in the study past time \(t\).

8 “Proper” accounting can be done with a sensible machine learner, like BART.

7 Technically, we want the proportion who survived, so we do 1 minus the proportion who experienced the event.

6 The earliest approach to do this is known as a Kaplan-Meier curve (Kaplan and Meier 1958), which essentially plots the observed percentages of the event over time. It is possible to compare different curves based on individuals’ attributes (for example comparing the curves of athletes and non-athletes).

5 This is because the BART estimate in these zones (a byproduct of a tree-based method) is just the mean of the nearest leaf node.

4 That being said, the predictions for points that use Gaussian processes still require the heavy GP computational cost. But, the data size for these points is expected to be small, and the process of building the BART forests is not affected at all.



Bayesian Optimization problem

  1. You have been tasked with finding the ONE PIECE! You have found bits and pieces of the fabled treasure and have the coordinates of successful and failed dives. Your task is to decide where to search next, with the goal of minimizing the number of expensive dives. Take a Bayesian optimization approach on this fictitious (and supplied) data. Use

    a) A Gaussian Process as your surrogate

    b) BART as your surrogate

    c) Play around with the expected improvement criterion
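A minimal sketch of one Bayesian optimization step for parts (a) and (c), assuming a one-dimensional search space and a made-up "treasure signal" in place of the supplied dive data. The GP uses a squared exponential kernel with hand-picked hyperparameters, and expected improvement (EI) for maximization is \(\text{EI}(x)=(\mu(x)-y^*-\xi)\Phi(z)+\sigma(x)\phi(z)\) with \(z=(\mu(x)-y^*-\xi)/\sigma(x)\), where \(y^*\) is the best value seen so far and \(\xi\) is a tunable exploration margin.

```r
# One Bayesian optimization step: GP surrogate (squared exponential kernel with
# fixed, hand-picked hyperparameters) + expected improvement (EI).
# The 1-D "treasure signal" and dive locations are hypothetical stand-ins for
# the supplied data.
se_kernel = function(a, b, theta = 1, ell = 0.15) {
  theta * exp(-outer(a, b, "-")^2 / (2 * ell^2))
}

gp_posterior = function(x_train, y_train, x_new, noise = 1e-4) {
  K = se_kernel(x_train, x_train) + noise * diag(length(x_train))
  K_s = se_kernel(x_new, x_train)
  post_mean = drop(K_s %*% solve(K, y_train))
  v = solve(K, t(K_s))
  post_var = pmax(1 - rowSums(K_s * t(v)), 1e-12)  # prior variance theta = 1
  list(mean = post_mean, sd = sqrt(post_var))
}

# EI for maximization; larger xi favors exploration over exploitation
expected_improvement = function(mu, sd, best, xi = 0.01) {
  imp = mu - best - xi
  z = imp / sd
  imp * pnorm(z) + sd * dnorm(z)
}

treasure_signal = function(x) exp(-(x - 0.7)^2 / 0.02) + 0.5 * exp(-(x - 0.2)^2 / 0.01)
x_dives = c(0.05, 0.3, 0.5, 0.9)          # hypothetical past dive locations
y_dives = treasure_signal(x_dives)

x_cand = seq(0, 1, length.out = 501)
post = gp_posterior(x_dives, y_dives, x_cand)
ei = expected_improvement(post$mean, post$sd, best = max(y_dives))
next_dive = x_cand[which.max(ei)]         # where the surrogate suggests diving next
```

Increasing xi (say to 0.1) pushes the suggestion toward regions the surrogate is unsure about. For part (b), the same expected_improvement() function could be fed posterior means and standard deviations computed from BART draws instead of gp_posterior().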

Physics informed statistical learning

Physics informed modeling can mean multiple things. We will discuss physics informed learning with respect to Gaussian processes and neural networks, but it can mean much more. Generally, our philosophy is that physical reasoning applied to real data is the way to approach this type of modeling12. The data obey the laws of physics, so let the data do the talking. A “machine”13 can help learn some of the complicated relations in the data, but the proposed model should use those learned relations to answer a meaningful question14. Put more formally, you want a reasonable statistical representation of a system that is parameterized in a physically sensible way.

As in BCF, we let the data guide us when it’s applicable (necessary to estimate \(\mu(\mathbf{x})\) and \(\tau(\mathbf{x})\) and de-confound the problem) and pull the leash when need be (not just taking a naive estimate of the treatment effect and calling it a day).

That is, instead of fitting data to a pre-defined physics model15, or just fitting a machine learning model16 and trusting its output entirely, propose a physically valid process to model the real data. The observed data, which are constrained by the true laws of physics rather than by our mental model of the correct physics, should be the backbone of the analysis. For a silly example highlighting this tug of war, imagine you are designing a self-driving car. Your physical model of the world says people must stop at red lights and go when the light turns green. A machine learning model “pattern matches” that people tend to run red lights, and so drivers are cautious right when the light turns green. You do not want your self-driving car to run reds, but you also want it to be cautious and not start the instant a light turns green. In the former case, we want to trust the physical model; in the latter, we want to trust the data-driven machine learning model. So designing the car means gluing the approaches together: building a system where the car follows “rules” obtained from a “physical” model but is still informed by the data where the physical model is limited.

In an actual example, (Yang, Tartakovsky, and Tartakovsky 2018) present a nice example of this type of modeling. They model a collection of physics experiments as a Gaussian process and generate probabilistic realizations of new experiments based on the empirical mean and covariance of the observed experiments. Pretty clever. Assuming physics experiments are realizations of a probabilistic process is very reasonable; determining the form of that process is up to the researcher. Trusting the data when you have many realizations makes sense. Imposing a structural form17 has its benefits but can also be restrictive. Yang et al. (2019) expand the (Yang, Tartakovsky, and Tartakovsky 2018) paper to incorporate information from both high- and low-fidelity experiments. They use the plentiful repeated low-fidelity samples to estimate the mean and covariance of a Gaussian process empirically. They then use a parameterized (aka traditional) Gaussian process to model the differences between the generated samples and the high-fidelity experiments, tuning the parameters of that Gaussian process.
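The empirical-GP idea can be sketched in a few lines. Here the "experiments" are hypothetical damped-wave runs with random amplitude and decay (an assumption standing in for real simulation output). The code estimates the mean function and covariance matrix from the runs and draws new synthetic runs from the resulting multivariate normal. Because the sample covariance has rank at most (number of runs minus one), the sampler uses an eigendecomposition rather than chol().

```r
# Treat repeated runs as GP realizations, estimate mean/covariance empirically,
# and sample new synthetic runs. The "simulator" is a hypothetical stand-in.
set.seed(2)
t_grid = seq(0, 4, length.out = 100)
n_runs = 30
runs = t(sapply(1:n_runs, function(r) {
  amp = rnorm(1, 1, 0.1)                 # random amplitude per run
  decay = runif(1, 0.5, 0.7)             # random damping per run
  amp * exp(-decay * t_grid) * cos(7.9 * t_grid)
}))                                      # n_runs x length(t_grid)

mu_hat = colMeans(runs)                  # empirical mean function
Sigma_hat = cov(runs)                    # empirical covariance (rank-deficient)

# Multivariate normal draws via eigendecomposition (works for singular Sigma)
sample_empirical_gp = function(n_new, mu, Sigma) {
  eig = eigen(Sigma, symmetric = TRUE)
  root = eig$vectors %*% diag(sqrt(pmax(eig$values, 0)))
  z = matrix(rnorm(length(mu) * n_new), nrow = length(mu))
  t(mu + root %*% z)                     # n_new x length(mu)
}

new_runs = sample_empirical_gp(5000, mu_hat, Sigma_hat)
matplot(t_grid, t(new_runs[1:10, ]), type = "l", lty = 1,
        xlab = "t", ylab = "simulated output")
```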

Anyways, let’s get back to reviewing the literature on Gaussian process based physics modeling. (Swiler et al. 2020) wrote a thorough review, from which we choose a few of the more promising/interesting methods. The main takeaway is that Gaussian processes represent a prior over function space, which allows us to incorporate knowledge of what we expect the system to produce (the expected function) a priori. That is, before we see any data, we can use the “physical constraints” we know to restrict the Gaussian process to certain shapes or ranges. For now, we are assuming one dimensional input, \(t\).

The motivating physics example is the damped harmonic oscillator corrupted by measurement noise:

\[\mathcal{L}u(t) = m\frac{\text{d}^2u(t)}{\text{d}t^2}+\mu\frac{\text{d}u(t)}{\text{d}t}+k\,u(t)=f(t)=0\]

The data are simulated from the analytic solution of this equation plus Gaussian measurement error, \(N(0, 0.1^2)\), i.e., with standard deviation 0.1.

# analytic solution
# https://beltoforion.de/en/harmonic_oscillator
# https://github.com/benmosely/harmonic-oscillator-pinn/blob/main/Harmonic%20oscillator%20PINN.ipynb
library(Deriv)    # symbolic differentiation (optional here)
library(pspline)  # sm.spline(), used below for numerical derivatives
set.seed(1)

N = 500     # number of time points
mu = 1      # damping coefficient
m = 0.8     # mass
k = 50      # spring constant
gamma = mu/(2*m)
# underdamped case: need sqrt(k/m) > gamma
stopifnot(sqrt(k/m) > gamma)

t = seq(from=0, to=4, length.out=N)

# closed-form solution with u(0) = 1 and u'(0) = 0
sim_eq = function(gamma, k, t){
  w = sqrt(k/m)                # natural frequency
  omega = sqrt(w^2-gamma^2)    # damped frequency
  phi = atan(-gamma/omega)     # phase
  A = 1/(2*cos(phi))           # amplitude
  return(2*exp(-gamma*t)*A*cos(phi+omega*t))
}
y = sim_eq(gamma, k, t)+rnorm(N, 0, 0.1)
plot(t, y, pch=16, col='#073d6d',
     xlab='t', ylab='y')
lines(t, sim_eq(gamma, k, t), lwd=4, col='#55AD89')

  1. Construct the kernel from the solution of the ODE/PDE: The first approach is essentially to construct the kernel from the Green’s function of the PDE/ODE using a Mercer basis, see (Albert and Rath 2020). The idea is

\[\mathcal{L}u=0\longrightarrow\mathcal{L}_{t}k(t,t')\mathcal{L}_{t'}^{T}=0 \] where \(u\) is the solution, \(\mathcal{L}\) is the operator, and \(k(t,t')\) is the GP kernel. The kernel can be written as a Mercer series:

\[ k(t,t') = \sum_{i=1}^{n}\lambda_{i}\phi_{i}(t)\phi_{i}(t') \]

where the \(\phi_{i}(\cdot)\) are basis functions that each solve the homogeneous equation, \(\mathcal{L}\phi_{i}=0\), and the \(\lambda_i>0\) are weights. The notion of expressing the covariance kernel using an inner product dates back to the pioneering work of Grace Wahba in the 1970s (Wahba 1973).

So if we can find a solution, we can encode that into the kernel, which does not seem particularly useful. Why would you even need the GP if you had a solution in hand? Maybe this could be useful as a proof of concept as the kernel used for the GP model in a “treed-GP” Gramacy approach (Gramacy and Lee 2008) (or a sum of trees if you want to torture your PC (Maia, Murphy, and Parnell 2024)), where a tree is built based off the GP likelihood. This still runs into the same issue though.
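A small sketch of building a kernel out of solutions, assuming the Mercer form \(k(t,t')=\sum_i\lambda_i\phi_i(t)\phi_i(t')\) with basis functions satisfying \(\mathcal{L}\phi_i=0\) for the damped oscillator (the weights \(\lambda_i\) are arbitrary). Every draw from this GP is an exact solution, which the finite-difference check confirms. It also illustrates the objection above: the kernel only spans the two known solutions.

```r
# Mercer-kernel sketch, assuming k(t, t') = sum_i lambda_i phi_i(t) phi_i(t')
# with basis functions that solve L phi = m phi'' + mu phi' + k phi = 0.
# For the underdamped oscillator these are exp(-gamma t) cos(omega t) and
# exp(-gamma t) sin(omega t); lambda = (1, 1) is an arbitrary choice.
m = 0.8; mu = 1; k = 50
gamma = mu / (2 * m)
omega = sqrt(k / m - gamma^2)

phi1 = function(t) exp(-gamma * t) * cos(omega * t)
phi2 = function(t) exp(-gamma * t) * sin(omega * t)

mercer_kernel = function(a, b, lambda = c(1, 1)) {
  lambda[1] * outer(phi1(a), phi1(b)) + lambda[2] * outer(phi2(a), phi2(b))
}

# The kernel has rank 2: the GP only lives on the span of the two solutions
t_small = seq(0, 4, length.out = 50)
ev = eigen(mercer_kernel(t_small, t_small), symmetric = TRUE, only.values = TRUE)$values
kernel_rank = sum(ev > 1e-8 * max(ev))

# A GP draw is u = c1 phi1 + c2 phi2 with c_i ~ N(0, lambda_i)
set.seed(3)
coefs = rnorm(2)
t_grid = seq(0, 4, length.out = 2001)
u = coefs[1] * phi1(t_grid) + coefs[2] * phi2(t_grid)

# Check the ODE residual with central finite differences (should be ~0)
h = t_grid[2] - t_grid[1]
i = 2:(length(t_grid) - 1)
u1 = (u[i + 1] - u[i - 1]) / (2 * h)
u2 = (u[i + 1] - 2 * u[i] + u[i - 1]) / h^2
residual = m * u2 + mu * u1 + k * u[i]
```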

  2. Incorporating constraints into the GP: If we can find an operator \(\mathcal{G}\) such that \(\mathcal{L}_t\mathcal{G}_t=0\), then we can write the resulting GP kernel as \(k_\text{physics}=\mathcal{G}_{t}k_\text{og}\mathcal{G}_{t'}^{T}\), where \(k_{\text{og}}\) is the original covariance function, probably the squared exponential. This method was introduced in (Jidling et al. 2017). (Lange-Hegermann 2018) provides an algorithm to calculate \(\mathcal{G}\), and (Besginow and Lange-Hegermann 2022) extend this method to systems of linear homogeneous ODEs (using Smith normal form algorithms). (Gaffney, Yang, and Ali 2022) use this method for an interesting application with a thermodynamic constraint from plasma physics, a fundamental restriction relating pressure, energy, and entropy. This is actually pretty useful, especially in physics where curl-free constraints are common ( \(\nabla\times f=0\xrightarrow{\text{implies}}\mathcal{G}=\nabla g\) because \(\nabla \times \nabla g = 0\) ), but the cases where \(\mathcal{G}\) can be found and/or computed easily are vanishingly rare.

  3. Direct implementation of the GP: The most promising approach, based on Raissi’s work (Raissi, Perdikaris, and Karniadakis 2017), is to directly encode the operator of the PDE/ODE into the covariance kernel, which works because a linear operator acting on a GP is still a GP. Early references include (Solak et al. 2002), who incorporate derivative information into the Gaussian process. Since a derivative is a linear operator, (Solak et al. 2002) note that the derivative of a Gaussian process is still a Gaussian process! They proceed by incorporating the observations and the derivatives of the observations into the Gaussian process, aided by the fact that common kernels (such as the squared exponential) are differentiable. So they calculate the covariance between function observations, between function observations and derivatives (potentially computed numerically), and between derivatives, allowing them to jointly model the function to be approximated and its derivatives. This is a natural way to incorporate derivative constraints into Gaussian process regression, as the derivative information adds knowledge of the system to the estimation of the function of interest. Section 9.4 of (Williams and Rasmussen 2006) provides a nice explanation of this.

    We can write the physics kernel as:

    \[ k_{\text{physics}}(t,t') = \mathcal{L}_{t}\mathcal{L}_{t'}k_{\text{OG}}(t,t') \]

\(\mathcal{L}_{t}\mathcal{L}_{t'}k\) in particular is quite gnarly, and requires a lot of algebraic skill or the use of symbolic software. For example, if \(\mathcal{L}_{t}=\frac{\partial^2 }{\partial t^2}\), then we need \(\frac{\partial^2 }{\partial t'^2}\left(\frac{\partial^2 }{\partial t^2}\left(\theta e^{-\frac{(t-t')^2}{\ell}}\right)\right)\). However, it is doable. We then need to choose the hyperparameters of the original GP kernel such that the constraint \(\mathcal{L}_tu=f\) is satisfied. While this is a good idea in theory, the implementation is a pain in the ass: tuning the hyperparameters is a lot of work, and doesn’t seem to work super well.
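Base R’s D() performs symbolic differentiation, which can take the algebra out of computing \(\mathcal{L}_{t}\mathcal{L}_{t'}k\). A minimal sketch for the example with \(\mathcal{L}_t=\partial^2/\partial t^2\), using the kernel \(\theta e^{-(t-t')^2/\ell}\) and writing tp for \(t'\); at \(t=t'\) the result should equal \(12\theta/\ell^2\).

```r
# Symbolic L_t L_t' k for L_t = d^2/dt^2 using base R's D(); tp stands for t'.
k_expr = quote(theta * exp(-(t - tp)^2 / ell))

d2_t = D(D(k_expr, "t"), "t")          # second derivative in t
k_physics = D(D(d2_t, "tp"), "tp")     # then second derivative in t'

# Evaluate an expression on all (a_i, b_j) pairs as a covariance matrix
eval_kernel = function(expr, a, b, theta = 1, ell = 1) {
  vals = eval(expr, list(t = rep(a, times = length(b)),
                         tp = rep(b, each = length(a)),
                         theta = theta, ell = ell))
  matrix(vals, nrow = length(a), ncol = length(b))
}

tt = seq(-2, 2, length.out = 5)
K_phys = eval_kernel(k_physics, tt, tt)
K_phys   # diagonal entries equal 12 * theta / ell^2 = 12 here
```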

Luckily, (Raissi, Perdikaris, and Karniadakis 2017) have a better method than what we were trying to do.

Suppose \(f=\mathcal{L}u\), where \(\mathcal{L}\) is a linear operator that takes \(\mathbf{x}\) as an input, and \(u\) is the solution of the operator equation that yields the physics constraint \(f\). Recall that, due to properties of linear operators and Gaussian processes, \(\mathcal{L}u\) is still a Gaussian process with a valid covariance kernel. (Raissi, Perdikaris, and Karniadakis 2017) then enforce the solution of this transformed GP (which has the “physics” encoded into the covariance by applying \(\mathcal{L}\) to the chosen GP kernel) by incorporating \(f\) into the estimation of \(u\). This is in contrast to what we did earlier, where we applied \(\mathcal{L}\) to \(u\) and tried to tune hyperparameters to ensure that \(\mathcal{L}u=f\) is satisfied. Instead, (Raissi, Perdikaris, and Karniadakis 2017) proceed by jointly learning the physical constraints \(y_f\) and the observed outcome \(y_u\) as functions of the inputs, using a Gaussian process prior on \(u\) and setting up a block covariance matrix to establish the relationship between \(f\) and \(u\).

This contribution allows for enhanced extrapolation of the observed outcome \(y_u\) outside the training data when physical constraints are known in the extrapolation zone.

(Raissi, Perdikaris, and Karniadakis 2017) consider \(\mathcal{L}_tu=f(t)\) and first place a GP prior on \(u\), with kernel \(k_u\). The linear operator acting on the solution is also a GP, where \(\mu(t)\) is the prior mean from the \(u\) GP.

\[ \mathcal{L}_tu(t)=f(t)\sim \text{GP}\left(\mathcal{L}_{t}\mu(t), \mathcal{L}_{t}\mathcal{L}_{t'}k_{u}(t,t')\right) \]

We are given some data for \(u(t)\) (the simulated observations of the damped harmonic oscillator) and data for \(f\) (computed from the input data using the damped harmonic oscillator equation). We calculate the derivatives numerically following this Stack Overflow answer, approximating them with splines18. Theoretically, \(f\) should be zero at every training point.

Numerical calculation of derivatives using splines
# https://stackoverflow.com/questions/61282256/numerical-derivative-in-r
# Calculate first and second derivatives of y(x) using spline smoothing
# (requires library(pspline), loaded above)
deriv_func = function(x, y){
  fit = sm.spline(x, y, norder=7)

  #For the first derivative:
  deriv1st <- predict(fit, x, 1)
  #plot(x,deriv1st)

  #For the second derivative:
  deriv2nd <- predict(fit, x, 2)
  #plot(x,deriv2nd)

  return(list(deriv1st, deriv2nd))
}

# Physics residual f = m*u'' + mu*u' + k*u at each training time (ideally ~0)
derivs = deriv_func(t, y)
extra_feature_train = m*derivs[[2]][,1]+
  mu*derivs[[1]][,1]+
  k*sim_eq(gamma,k,t)
#plot(extra_feature_train)

(Raissi, Perdikaris, and Karniadakis 2017) go one step further and create a joint GP for \(u\) and \(f\), of the form:

\[ \begin{pmatrix} u(t_1)\\ f(t_2) \end{pmatrix}\sim \text{GP}\left(\begin{bmatrix} \mu(t_1)\\ \mathcal{L}_{t}\mu(t_2) \end{bmatrix}, \begin{bmatrix} k(t_1, t_1')+\sigma_u^2\mathbf{I}_{n_u}&\mathcal{L}_{t'}k(t_1, t_2')\\ \mathcal{L}_{t}k(t_2, t_1')&\mathcal{L}_{t}\mathcal{L}_{t'}k(t_2,t_2')+\sigma_f^2\mathbf{I}_{n_f} \end{bmatrix}\right) \]

where \(k\) is the squared exponential or whatever.
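The joint construction can be sketched in R for the damped oscillator, reusing base R’s D() for the kernel derivatives. This is a sketch under stated assumptions, not Raissi et al.’s implementation: the hyperparameters \(\theta=1\), \(\ell=0.2\) and the noise variances are fixed by hand rather than learned, the prior mean is zero, and the data are re-simulated so the block is self-contained. Displacement data are supplied only on \([0,2]\), while the physics points \(f=0\) cover \([0,4]\), which is the setting where the physics is supposed to help extrapolation.

```r
# Joint GP over u and f = L u, with L u = m u'' + mu u' + ks u (ks = spring
# constant). Hyperparameters and noise variances are fixed by hand (assumed).
set.seed(4)
m = 0.8; mu = 1; ks = 50
gamma = mu / (2 * m); omega = sqrt(ks / m - gamma^2)
u_true = function(t) exp(-gamma * t) * (cos(omega * t) + (gamma / omega) * sin(omega * t))

# Squared exponential kernel; tp stands for t'
k_expr = quote(theta * exp(-(t - tp)^2 / (2 * ell^2)))
L_t  = function(e) bquote(m * .(D(D(e, "t"), "t")) + mu * .(D(e, "t")) + ks * .(e))
L_tp = function(e) bquote(m * .(D(D(e, "tp"), "tp")) + mu * .(D(e, "tp")) + ks * .(e))
k_uf = L_tp(k_expr)                      # Cov(u(t), f(t'))
k_fu = L_t(k_expr)                       # Cov(f(t), u(t'))
k_ff = L_t(L_tp(k_expr))                 # Cov(f(t), f(t'))

hyper = list(theta = 1, ell = 0.2, m = m, mu = mu, ks = ks)
kmat = function(expr, a, b) {
  env = c(list(t = rep(a, times = length(b)), tp = rep(b, each = length(a))), hyper)
  matrix(eval(expr, env), nrow = length(a), ncol = length(b))
}

# Noisy displacement data on [0, 2] only; physics points (f = 0) on all of [0, 4]
t_u = seq(0, 2, length.out = 40)
y_u = u_true(t_u) + rnorm(length(t_u), 0, 0.1)
t_f = seq(0, 4, length.out = 80)
y_f = rep(0, length(t_f))
sig2_u = 0.1^2; sig2_f = 1e-2

K = rbind(cbind(kmat(k_expr, t_u, t_u) + sig2_u * diag(length(t_u)), kmat(k_uf, t_u, t_f)),
          cbind(kmat(k_fu, t_f, t_u), kmat(k_ff, t_f, t_f) + sig2_f * diag(length(t_f))))

# Posterior mean of u (zero prior mean) on a grid extending past the u data
t_star = seq(0, 4, length.out = 200)
K_star = cbind(kmat(k_expr, t_star, t_u), kmat(k_uf, t_star, t_f))
u_pred = drop(K_star %*% solve(K, c(y_u, y_f)))

plot(t_star, u_true(t_star), type = "l", lwd = 2, xlab = "t", ylab = "u")
points(t_u, y_u, pch = 16)
lines(t_star, u_pred, lty = 2, lwd = 2)
```

Comparing u_pred on \((2,4]\) against a fit that drops the \(f\) block is a simple way to see how much the physics points contribute; learning \(\theta\) and \(\ell\) by maximizing the marginal likelihood of the stacked vector \((y_u, y_f)\) would be the natural next step.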

With this kernel in hand, the idea would be to use it as the kernel for the GPs grafted onto the leaf nodes, as Maggie and Richard did (Wang, He, and Hahn 2024). Since this is a multi-output GP, tinkering (if at all possible) would be required. The same goes for the “build the tree with respect to the GP” approach.

  4. Raissi (shocker) also suggested a nifty idea for GPs in a numerical integration scheme (Raissi, Perdikaris, and Karniadakis 2018). (Raissi and Karniadakis 2018) build on those ideas, noting that \(\mathcal{L}_{t}h^{n}=h^{n-1}\), which builds on Euler’s idea of time stepping, \(h^n=h^{n-1}+\Delta t\,\frac{\text{d}h}{\text{d}t}\). Successive Euler steps are joint Gaussian processes:

    \[ \begin{pmatrix} h^n\\ h^{n-1} \end{pmatrix}\sim\text{GP}\left(0, \begin{bmatrix} k^{n,n}&k^{n,n-1}\\ k^{n-1,n}&k^{n-1,n-1} \end{bmatrix}\right) \]

    where \(k^{n,n}=k\), \(k^{n,n-1}=\mathcal{L}_{t'}k\), \(k^{n-1,n}=\mathcal{L}_{t}k\), and \(k^{n-1, n-1}=\mathcal{L}_t\mathcal{L}_{t'}k\).

    Maybe a BART kernel could be used for \(k\), but, as far as we can tell, that is ruled out because we cannot analytically compute the derivatives needed. (Han et al. 2022) present a similar model to solve time-dependent PDEs, but for biological systems.

  5. Empirical GP covariance/mean: The final nifty idea (thanks to the review paper (Swiler et al. 2020)) is to assume the results from multiple available physics simulation runs are realizations from a Gaussian process, and to estimate the mean and covariance functions empirically. These then serve as the inputs to a multivariate normal, from which many new simulations can be drawn. I’m not sure physicists will love this, as they cannot interpolate (or extrapolate) unless they write down a new model for interpolation.

Drawbacks: Like every modeling idea, there are always downsides. (Mohan, Chattopadhyay, and Miller 2024) present a really great investigation into the issues with using “neuralPDE” architectures. Neural PDEs trace their roots back to “neuralODEs” (Chen et al. 2018). We provide a brief explanation following this blog. Euler’s method discretizes a continuous derivative relationship, \(\frac{\text{d}h}{\text{d}t}\), into time-step jumps, \(h_{t+1}-h_{t}=f(h_t, \theta_{t})\), where \(f(h_t,\theta_{t})\) is represented by a hidden layer in a neural network. Neural ODEs make the progression through the hidden-layer representations continuous. This is in slight contrast to PINNs, which include the ODE/PDE terms in the loss function, whereas neuralODEs/PDEs model the time derivative of the evolution of a system directly with a neural network. However, this requires specifying a discretization method, which (Mohan, Chattopadhyay, and Miller 2024) explore in depth. In particular, they find poor extrapolation and remarkable sensitivity of neuralPDEs to errors in large physics simulations. The authors show that neuralPDEs/ODEs are “systematically biased” due to learning “the artifacts in the simulation training data arising from the discretized Taylor Series truncation error of the spatial derivative”. Really fascinating and worth a read. We’ll leave you with this excerpt from the discussion:

Finally, our effort is restricted to differentiable programming models where NN is part of a model with numerical schemes. Yet, regardless of the type of ML model, our central statement in this work still holds: The ground truth from PDE simulations is associated with a TS truncation error tied to the numerical discretization and initial conditions, which an ML model implicitly learns. We must now pose a larger question. Can we mathematically represent the error from discretized ground truth on any ML model before training? A rigorous study quantifying the errors incurred by this in the learning process would yield valuable insight into the inherent limits to how much a model can generalize, and aid in uncertainty quantification and trustworthiness. This analysis is complicated because, as seen from this work, it does not have a generic analytic form and is instead dependent on the equation, discretization, and the model. The problem is further compounded when we consider SciML foundation models, which employ large datasets from various numerical solvers. Our analysis shows the blanket assumption of “ground truth” extended to this diverse dataset - without seriously considering the mathematical idiosyncracies of the solvers that generated them - can negatively impact ML models. A parallel exists in the computer vision community, where images can be contaminated by a tiny number of pixels that can cause a model to misbehave without explanation. This idea is widely known as an adversarial attack [43, 44], where datasets with “invisible” (at least to human eyes) and hard-to-find numerical artifacts poison [45, 46] high-quality datasets to destroy model performance.

Our investigation unfortunately raises the concern that adversarial attacks (often unintentional) can also be a factor in large PDE datasets generated by numerical schemes that are unknown to the SciML developer, as is often the case in foundation model research. We have shown that even for a problem as simple as the Burgers equation, the solution generated by two solvers with different finite difference schemes are identical - but the “invisible” Taylor Series truncated terms can make the model learn different quantities and undergo catastrophic failure. Further still, these datasets can come from various families of numerics: finite difference, spectral, finite element, each with their own numerical approximations, making the error analysis much more complex than that presented in this work. The computational physics community has long considered these factors when designing numerical solvers for verification and validation. There is a long history of numerical analysis that studies the errors of a numerical method. For example, Richardson extrapolation [47, 48], developed for engineering models of bridges, can be thought of as one of the earliest attempts at addressing this gap. The approach estimates both the form of the true solution (“ground truth”) and the truncation error. This idea may perhaps enable more sophisticated study in NeuralPDEs. If we are to build robust and reliable NeuralPDE models for science, we must extend these models the same rigor [49] and analysis, and this work is a modest step in that direction.
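The truncation error the excerpt describes is easy to see in the simplest setting. A sketch, assuming the test problem \(\frac{\text{d}y}{\text{d}t}=-y\), \(y(0)=1\) (an illustrative choice, not an example from the paper): forward Euler, \(h_{t+1}=h_t+\Delta t\,f(h_t)\), is first order, so halving the step roughly halves the error at \(t=1\). A model trained on output from one step size would learn that solver’s error along with the dynamics.

```r
# Forward Euler, h_{t+1} = h_t + dt * f(h_t), on the test problem dy/dt = -y,
# y(0) = 1 (an illustrative choice). Its global error is O(dt): the numerical
# "ground truth" carries a truncation error tied to the step size.
euler_solve = function(f, y0, t_end, dt) {
  y = y0
  for (i in seq_len(round(t_end / dt))) {
    y = y + dt * f(y)            # one explicit Euler step
  }
  y
}

decay = function(y) -y
exact = exp(-1)                  # true solution at t = 1
err_coarse = abs(euler_solve(decay, 1, 1, 0.1) - exact)
err_fine = abs(euler_solve(decay, 1, 1, 0.05) - exact)
error_ratio = err_coarse / err_fine   # roughly 2: halving dt halves the error
```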

PINNs also have well-known extrapolation issues (cite).

Summary: The basic idea of (Raissi and Karniadakis 2018) is to regularize a neural network prediction toward a physics model by taking advantage of the differentiable form (granted by the chain rule) of the loss function in common (multilayer perceptron) neural network architectures. Additionally, parameters of the physics model can be learned by training a multi-output neural network. Gaussian process based physics informed modeling essentially constrains the realizations/samples of a Gaussian process posterior to the known physics, i.e., it prevents non-physical realizations of the GP from being sampled, or makes sure realizations follow the constraint of a physics equation.

BART-based physics learning seems like an interesting approach. Neural networks and GPs may not be best equipped for discontinuous constraints or for solving for constants that vary over time (or with respect to some other covariate). For example, phase changes may cause the physics to differ across regions of covariate space. BART, being a well-regularized tree ensemble method, should be able to adapt to this situation.

One thing missing from the literature is how much better these methods work than simply treating the physics as a “reference” function. That is, treat the physics model as a “mean term” in a Gaussian process, which can be subtracted off, and then let your ML model learn the residuals.
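A sketch of that baseline, assuming a made-up physics model that omits a slow sinusoidal term, with smooth.spline() standing in for the ML model: subtract the physics prediction, fit the residuals, and add the two back together.

```r
# Physics model as a reference mean, ML model on the residuals.
# The reference function, the "missing" term, and the noise level are made up
# for illustration; smooth.spline() stands in for a BART or GP residual model.
set.seed(5)
t_obs = seq(0, 4, length.out = 300)
physics_model = function(t) exp(-0.6 * t) * cos(7.9 * t)   # trusted reference
missing_term = function(t) 0.5 * sin(1.5 * t)              # what the physics omits
y = physics_model(t_obs) + missing_term(t_obs) + rnorm(length(t_obs), 0, 0.1)

residuals_physics = y - physics_model(t_obs)      # subtract off the physics "mean"
resid_fit = smooth.spline(t_obs, residuals_physics)   # learn what the physics misses
y_hat = physics_model(t_obs) + predict(resid_fit, t_obs)$y

truth = physics_model(t_obs) + missing_term(t_obs)
rmse_physics_only = sqrt(mean((physics_model(t_obs) - truth)^2))
rmse_combined = sqrt(mean((y_hat - truth)^2))
c(physics_only = rmse_physics_only, combined = rmse_combined)
```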

Another concept we talked about was “surrogate” models. The Bayesian prior on the function, like BART, serves as a proxy for the physics, which then enables uncertainty quantification in a much more computationally efficient way than simply re-running thousands of expensive simulations. Or, worse, changing one variable at a time while keeping the rest constant…when we know the systems that generate the data involve variables that are coupled :(.

Cool example of physics in sports statistical modeling.

Physics informed BART

It was always gonna end with BART.

18 Although the derivatives can be calculated analytically for the damped harmonic oscillator, with some help from Wolfram Alpha and the known solution in hand, the following script is a nifty example of how to calculate derivatives using splines.

17 For example by specifying a prior mean and covariance kernel

16 We saw that PINNs (Raissi and Karniadakis 2018) were a nice compromise here between the neural network and the physics model.

15 Which is still useful if you are really confident in the physics.

14 Note the similarity to the Bayesian causal forest (BCF) approach, which provides principled and effective machinery for causal effect estimation: \(Y=\mu(\mathbf{X})+\tau(\mathbf{X})Z+\varepsilon\). We ask a reasonable (and actionable) question, such as whether sleeping 1 extra hour a night causes a person with certain covariates (say a 28 year old cool guy) to live longer. We include confounding covariates based on our knowledge of the systems involved in sleep and mortality. We previously learned that a naive estimate based solely on the data in one treatment level versus another would be incorrect because of the confounding, necessitating the inclusion of the covariates to control for. Finally, we place BART priors on \(\mu(\mathbf{X})\) and \(\tau(\mathbf{X})\). BART is a machine learning tool with clever regularization that avoids overfitting and can learn the complicated relations between the confounding (and moderating & prognostic) covariates included in \(\tau(\mathbf{X})\) and \(\mu(\mathbf{X})\).

13 Another point to consider here is we want a good machine learning model. One that will overfit or underfit does not capture the true relations/patterns in the data and is of little use.

12 My view on this was influenced by talented former colleagues Richard Hahn and JJ Ruby.

Teach us about your field and find some data and something from the class that might be relevant. Some suggestions:

a) Physics:

Physics informed priors: Using a Bayesian approach to reconstruct the error analysis of an experiment. The priors chosen reflect the known physics and provide a better quantification of the uncertainty associated with the experimental results. (Ressel et al. 2022)

PINN application: an application of a physics informed neural network that subtracts off a known physics model and models the residuals (Gaffney, Yang, and Ali 2022).

b) Education:

c) Psychology:

d) Health

Write up a report on two of the following important data science projects (or submit your own data science project that belongs on this list).

The qualifications are that the method/application used must have been reliant on boatloads of data in some sense.

  1. iPhone keyboard: This ColdFusion video is excellent. The original iPhone was almost derailed because building a keyboard that fit on its small screen, with the touch screen technology of the time, proved an engineering nightmare. The solution came from data science. Apple was able to enlarge certain keys on the keyboard (without the user seeing) based on the probability of the next letter to be typed, learned from past typing patterns.

13.2 Bonus

Code to run a survival analysis using the IBM attrition data. This (fictional) workforce retention dataset records whether employees leave the company; employees who have not left are right censored, since we do not observe how long they will ultimately stay at IBM. The code below creates interactive plots of the survival and hazard curves for a subset of IBM workers, and aggregates them by level of current job satisfaction.

Code for a survival analysis using BART.
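options(warn=-1)  # suppress warnings
suppressMessages(library(stochtree))      # BART and heteroskedastic BART
suppressMessages(library(dplyr))
suppressMessages(library(tidyverse))
suppressMessages(library(caret))          # dummyVars() for one-hot encoding
#library(plotly)
suppressMessages(library(modeldata))      # attrition data
suppressMessages(library(reactable))      # reactable(), colDef()
suppressMessages(library(reactablefmtr))  # react_sparkline(), fivethirtyeight()
suppressMessages(library(dataui))         # needed by react_sparkline()
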
# Attrition dataset
data(attrition)

# Event indicator: 1 if the employee left (Attrition == "Yes"), 0 if right censored
w = as.numeric(attrition$Attrition) - 1

n = length(w)

# Log of the observed time at the company (small offset because some values are 0)
y = log(attrition$YearsAtCompany+1e-3)

# Treatment is working overtime (1 = yes, 0 = no)
z = as.numeric(attrition$OverTime == 'Yes')
# Drop the ordering of ordinal features so they are treated as unordered categories
attrition$Education = factor(attrition$Education, ordered=FALSE)
attrition$JobInvolvement = factor(attrition$JobInvolvement, ordered=FALSE)
attrition$JobSatisfaction = factor(attrition$JobSatisfaction, ordered=FALSE)
attrition$PerformanceRating = factor(attrition$PerformanceRating, ordered=FALSE)
attrition$StockOptionLevel = factor(attrition$StockOptionLevel, ordered=FALSE)
attrition$WorkLifeBalance = factor(attrition$WorkLifeBalance, ordered=FALSE)
# Covariates
X = attrition[, c('Age', 'DistanceFromHome', 'Gender', 'Education',

                  'Department', 'HourlyRate',
                  'JobInvolvement', 'JobSatisfaction',
                  #'MaritalStatus','NumCompaniesWorked',
                  #'YearsInCurrentRole', 'YearsSinceLastPromotion',
                  'PercentSalaryHike',
                  'PerformanceRating',
                  #'StockOptionLevel',
                  'WorkLifeBalance')]

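# One-hot encode the covariates (one 0/1 column per factor level)
one_hot = dummyVars(" ~ .", data=X)

X_new = data.frame(predict(one_hot, newdata=X))

# Treat every dummy column (all but the numeric covariates) as an unordered factor;
# selecting by name avoids hard-coding column positions
num_cols = c('Age', 'DistanceFromHome', 'HourlyRate', 'PercentSalaryHike')
cat_cols = setdiff(colnames(X_new), num_cols)

for (col in cat_cols) {
  X_new[, col] = factor(X_new[, col], ordered = FALSE)
}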


# Hold-out check of the probit-style attrition model (75% train / 25% test)
samp_size = floor(0.75*n)

set.seed(123)
train_ind = sample(seq_len(n), size=samp_size)
X_train = X_new[train_ind,]
X_test = X_new[-train_ind,]
w_train = w[train_ind]
w_test = w[-train_ind]
num_mcmc = 110
bart_probit = bart(X_train = as.data.frame(X_train),
                   y_train = 0.25*(2*as.numeric(w_train) - 1),
                   X_test = as.data.frame(X_test),
                   num_gfr = 20,
                   num_mcmc = num_mcmc,
                   num_burnin=100,
                   mean_forest_params = list(min_samples_leaf=1, max_depth=20,
                                             alpha=0.95, beta=2,
                                             sigma2_leaf_shape = 6,sigma2_leaf_scale=0.25,
                                             num_trees=100),
                   general_params = list(verbose=F,
                                         num_chains=5,
                                         keep_every=1,
                                         sigma2_global_init=1,
                                         sample_sigma2_global=F))

#pROC::auc(w_test,rowMeans(pnorm(bart_probit$y_hat_test)))


# Refit the attrition model on all employees for inference
bart_probit = bart(X_train = as.data.frame(X_new),
                   y_train = 0.25*(2*as.numeric(w) - 1),
                   num_gfr = 50,
                   num_mcmc = num_mcmc,
                  mean_forest_params = list(min_samples_leaf=1, max_depth=20,
                                             alpha=0.95, beta=2,
                                             sigma2_leaf_shape = 6,sigma2_leaf_scale=0.25,
                                             num_trees=100),
                   general_params = list(verbose=F,
                                         num_chains=5,
                                         keep_every=1,
                                         sigma2_global_init=1,
                                         sample_sigma2_global=F))

bart_probit_z = bart(X_train = as.data.frame(cbind(X_new,z)),
                     y_train = 0.25*(2*as.numeric(w) - 1),
                     num_gfr = 50,
                     num_mcmc = num_mcmc,
                    mean_forest_params = list(min_samples_leaf=1, max_depth=20,
                                      alpha=0.95, beta=2,
                                      sigma2_leaf_shape = 6,
                                      sigma2_leaf_scale=0.25,
                                      num_trees=100),
                     general_params = list(verbose=F,
                                           sigma2_global_init=1,
                                           sample_sigma2_global=F))

pred_probit = predict(bart_probit, as.data.frame(X_new))
pred_probit_z = predict(bart_probit_z, as.data.frame(cbind(X_new,z)))
#pred_probit_z = predict(bart_probit_z, as.data.frame(cbind(X_new,
#                                                     rep(z, nrow(X_new)))))
pi_w = rowMeans(pnorm(pred_probit$y_hat))
pi_w_z = rowMeans(pnorm(pred_probit_z$y_hat))
# Settings for heteroskedastic BART: a mean forest plus a variance forest
bart_params_mean = list(num_trees = 100,
                        alpha = 0.95, beta = 2,
                        min_samples_leaf = 5,
                        sample_sigma2_leaf = F)
bart_params_variance = list(num_trees = 20,
                            alpha = 0.95, beta = 2,
                            min_samples_leaf=20,
                            sample_sigma2_leaf=F)



# Predict y = log(T_obs) with heteroskedastic BART.
# This is an S-learner: a single model is fit on all units with w (and pi_w)
# as covariates, then predictions are made with w set to 1 and to 0.
# A T-learner would instead fit separate models on the w == 1 and w == 0 subsets.

Q_bart = stochtree::bart(X_train = as.data.frame(cbind(X_new, w, pi_w)),
                         y_train = y,
                         num_gfr = 20,
                         num_mcmc = num_mcmc,
                         general_params = list(sample_sigma2_global=T, verbose=F,
                                               num_chains=5, keep_every=1),
                         mean_forest_params = bart_params_mean,
                         variance_forest_params = bart_params_variance)



Q_bart_t = stochtree::bart(X_train = as.data.frame(cbind(X_new, w, z, pi_w)),
                           y_train = y,
                           num_gfr = 20,
                           num_mcmc = num_mcmc,
                           general_params = list(sample_sigma2_global=T, verbose=F,
                                                 num_chains=5, keep_every=1),
                           mean_forest_params = bart_params_mean,
                           variance_forest_params = bart_params_variance)


Q1_pred = predict(Q_bart, as.data.frame(cbind(X_new,w=rep(1,
                                            nrow(X_new)), pi_w=pi_w)))
Q0_pred = predict(Q_bart, as.data.frame(cbind(X_new,w=rep(0,
                                                 nrow(X_new)), pi_w=pi_w)))
Q0 = rowMeans(Q0_pred$y_hat)
Q1 = rowMeans(Q1_pred$y_hat)

Q0_sigma = sqrt(rowMeans(Q0_pred$variance_forest_predictions))
Q1_sigma = sqrt(rowMeans(Q1_pred$variance_forest_predictions))



# Causal inference: predict with overtime set to 1 (suffix _t) and to 0 (suffix _c)

Q0_t_pred = predict(Q_bart_t, as.data.frame(cbind(X_new,w=rep(0,
                                                   nrow(X_new)),
                                         z=rep(1,
                                             nrow(X_new)), pi_w=pi_w)))
Q0_c_pred = predict(Q_bart_t, as.data.frame(cbind(X_new,w=rep(0,
                                                   nrow(X_new)),
                                         z=rep(0,
                                             nrow(X_new)), pi_w=pi_w)))
Q0_t = rowMeans(Q0_t_pred$mean_forest_predictions)
Q0_c = rowMeans(Q0_c_pred$mean_forest_predictions)
Q1_t_pred = predict(Q_bart_t, as.data.frame(cbind(X_new,w=rep(1,
                                                   nrow(X_new)),
                                        z= rep(1,
                                             nrow(X_new)),pi_w=pi_w)))
Q1_c_pred = predict(Q_bart_t, as.data.frame(cbind(X_new,w=rep(1,
                                                   nrow(X_new)),
                                         z=rep(0,
                                             nrow(X_new)),pi_w=pi_w)))
Q1_t = rowMeans(Q1_t_pred$mean_forest_predictions)
Q1_c = rowMeans(Q1_c_pred$mean_forest_predictions)

Q0_sigma_t = sqrt(rowMeans(Q0_t_pred$variance_forest_predictions))
Q0_sigma_c = sqrt(rowMeans(Q0_c_pred$variance_forest_predictions))
Q1_sigma_t = sqrt(rowMeans(Q1_t_pred$variance_forest_predictions))
Q1_sigma_c = sqrt(rowMeans(Q1_c_pred$variance_forest_predictions))



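t = attrition$YearsAtCompany  # observed years at company

# Hazard at each employee's observed time (Equation 5.19), written in the same
# form as the hazard curves below:
# h(t) = pi_w * f1(t) / (pi_w * S1(t) + (1 - pi_w) * S0(t))
p_t = pi_w * dlnorm(t, Q1, Q1_sigma)
denom = (1 - plnorm(t, Q1, Q1_sigma))*pi_w + (1 - plnorm(t, Q0, Q0_sigma))*(1 - pi_w)

h_t = p_t/denom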

# Generate survival curves on a grid of t_time points from min(t) to t_max
t_time = 100
t_max = 25  # truncate the grid at 25 years rather than using max(t)
t_val = seq(from=min(t), to=t_max, length.out = t_time)

surv_vals = rep(NA,t_time)
S_t = matrix(NA, nrow=n, ncol=t_time)

for(j in 1:n){
  surv_vals = rep(NA, t_time)  # reset so a failed integral gives NA rather than a stale value
  for(q in 1:t_time){
    tryCatch({
    surv_vals[q] = integrate(function(a)
      (dlnorm(a, Q1[j], Q1_sigma[j])*(pi_w[j]))/
        ((1-plnorm(a,Q1[j],Q1_sigma[j]))*pi_w[j] + ((1-plnorm(a, Q0[j], Q0_sigma[j]))*(1-pi_w[j]))),
      lower=0, upper=t_val[q],subdivisions=200,stop.on.error=F)$value
    },
    error=function(e){})

  }
  S_t[j,] = exp(-surv_vals)
}



t_time = 100
t_val = seq(from=min(t),to=t_max, length.out = t_time)
surv_vals_treat = rep(NA,t_time)
S_t_treat = matrix(NA, nrow=n, ncol=t_time)
surv_vals_control = rep(NA,t_time)
S_t_control = matrix(NA, nrow=n, ncol=t_time)

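for(j in 1:n){
  surv_vals_treat = rep(NA, t_time)    # reset so failed integrals give NA
  surv_vals_control = rep(NA, t_time)  # rather than stale values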
  for(q in 1:t_time){
    tryCatch({
    surv_vals_treat[q] = integrate(function(a)
      (dlnorm(a, Q1_t[j], Q1_sigma_t[j])*pi_w[j])/
        ((1-plnorm(a,Q1_t[j],Q1_sigma_t[j]))*pi_w[j] + ((1-plnorm(a, Q0_t[j], Q0_sigma_t[j]))*(1-pi_w[j]))),
      lower=0, upper=t_val[q],subdivisions=200,stop.on.error=F)$value},
    error=function(e){})
    tryCatch({
    surv_vals_control[q] = integrate(function(a)
      (dlnorm(a, Q1_c[j], Q1_sigma_c[j])*pi_w[j])/
        ((1-plnorm(a,Q1_c[j],Q1_sigma_c[j]))*pi_w[j] + ((1-plnorm(a, Q0_c[j], Q0_sigma_c[j]))*(1-pi_w[j]))),
      lower=0, upper=t_val[q],subdivisions=200,stop.on.error=F)$value},
  error=function(e){})

  }
  S_t_treat[j,] = exp(-surv_vals_treat)
  S_t_control[j,] = exp(-surv_vals_control)
}


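# Hazards over the full time grid, with overtime set to 1 (treat) or 0 (control)
# Note: the numerators use pi_w, while the denominators use pi_w_z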
h_t_individual_treat = matrix(NA, nrow=n, ncol=t_time)
h_t_individual_control = matrix(NA, nrow=n, ncol=t_time)

for (j in 1:n){
  for (q in 1:t_time){
    h_t_individual_treat[j,q] =
      (pi_w[j]*dlnorm(t_val[q], Q1_t[j], Q1_sigma_t[j]))/
      ((1-plnorm(t_val[q],Q1_t[j],Q1_sigma_t[j]))*pi_w_z[j] +
         ((1-plnorm(t_val[q], Q0_t[j], Q0_sigma_t[j]))*(1-pi_w_z[j])))

    h_t_individual_control[j,q] =
      (dlnorm(t_val[q], Q1_c[j], Q1_sigma_c[j])*pi_w[j])/
      ((1-plnorm(t_val[q],Q1_c[j],Q1_sigma_c[j]))*pi_w_z[j] +
         ((1-plnorm(t_val[q], Q0_c[j], Q0_sigma_c[j]))*(1-pi_w_z[j])))

    }
}

# Name the time-grid columns time_1, ..., time_100
colnames(h_t_individual_control) = paste0('time_', 1:t_time)
colnames(h_t_individual_treat) = paste0('time_', 1:t_time)


# Control and treated hazard curves for a single employee (row 22)
control_mean = h_t_individual_control[22,]

treat_mean = h_t_individual_treat[22,]


X_check = data.frame(hazards=round(h_t,3), X_new)
#DT::datatable(X_check)


# Hazards over the full time grid (from Q_bart, which does not include overtime)

h_t_individual = matrix(NA, nrow=n, ncol=t_time)

for (j in 1:n){
  for (q in 1:t_time){
    h_t_individual[j,q] =
      (pi_w[j]*dlnorm(t_val[q], Q1[j], Q1_sigma[j]))/
      ((1-plnorm(t_val[q],Q1[j],Q1_sigma[j]))*pi_w[j] + ((1-plnorm(t_val[q], Q0[j], Q0_sigma[j]))*(1-pi_w[j])))
  }
}
colnames(h_t_individual) = unlist(lapply(1:t_time, function(h)
  paste0('time_', h)))
plot2 = data.frame(h_t_individual) %>%
  tidyr::pivot_longer(cols =  starts_with('time'),
                      names_to = 'time' ) %>%
  mutate(time = rep(t_val, n),  # the time grid (years) the hazards were computed on
         person = unlist(lapply(1:n, function(j)rep(j, t_time))),
         RD = unlist(lapply(1:n, function(j)rep(X_new$Department.Research_Development[j],
                                                t_time))))%>%
  ggplot(aes(x=time,y=value, group=person))+#,color=as.factor(RD)))+
  geom_line( alpha=0.10,linewidth=0.08, color='#012296')+
  scale_color_manual(values=c('#A3A7D2','#8b1a1a'))+
  stat_summary(aes(x=time,
                   y=value, group=F),fun = "mean",
               colour = '#d47c17',
               lwd=2,alpha=1, geom = "line", lty=1)+

  theme_minimal()
colnames(S_t) = unlist(lapply(1:t_time, function(h)
  paste0('time_', h)))
plot1 = data.frame(S_t) %>%
  tidyr::pivot_longer(cols =  starts_with('time'),
                      names_to = 'time' ) %>%
  mutate(time = rep(t_val, n),  # the time grid (years) the survival curves were computed on
         person = unlist(lapply(1:n, function(j)rep(j, t_time))),
         RD = unlist(lapply(1:n, function(j)rep(X_new$Department.Research_Development[j],
                                                t_time))))%>%
  ggplot(aes(x=time,y=value, group=person))+#,color=as.factor(RD)))+
  geom_line( alpha=0.16,linewidth=0.08, color='#012296')+
  scale_color_manual(values=c('#A3A7D2','#8b1a1a'))+
  stat_summary(aes(x=time,
                   y=value, group=F),fun = "mean",
               colour = '#d47c17',
               lwd=2,alpha=1, geom = "line", lty=1)+
  xlab('Years at company') + ylab('Probability employee stays past t years') +
  theme_minimal(base_size=18)
#plot2
#plot1
# Summarize each employee's peak hazard and the time at which it occurs


# Hazards over the full time grid, recomputed for the peak-hazard summary

h_t_individual = matrix(NA, nrow=n, ncol=t_time)
t_val = seq(from=min(t),to=t_max, length.out = t_time)

for (j in 1:n){
  for (q in 1:t_time){
    h_t_individual[j,q] =
     ( dlnorm(t_val[q], Q1[j], Q1_sigma[j])*pi_w[j])/
      ((1-plnorm(t_val[q],Q1[j],Q1_sigma[j]))*pi_w[j] +
         ((1-plnorm(t_val[q], Q0[j], Q0_sigma[j]))*(1-pi_w[j])))
  }
}
colnames(h_t_individual) = unlist(lapply(1:t_time, function(h)
  paste0('time_', h)))
h_t_max = sapply(1:n, function(k)max(h_t_individual[k,]))
h_t_max_time = sapply(1:n, function(k)t_val[which.max(h_t_individual[k,])])
p2 = data.frame(time=h_t_max_time, max_val = h_t_max) %>%
  ggplot(aes(x=time, y=max_val))+geom_point(color='#073d6d', size=1, alpha=0.9)+
  #geom_density_2d(col='#55Ad89', lwd=.5, alpha=0.8)+
  theme_minimal()

#ggExtra::ggMarginal(p2, type='histogram', fill='#55AD89')


colnames(S_t) = unlist(lapply(1:ncol(S_t), function(k)paste0('time_',k)))

dim_St = t_time  # number of time points per person, equal to ncol(S_t)

colnames(h_t_individual) = unlist(lapply(1:ncol(h_t_individual), function(k)paste0('time_',k)))
h_t = as.data.frame(h_t_individual)
S_t = as.data.frame(S_t) %>%
  tidyr::pivot_longer(cols=starts_with('time'),
                      names_to = 'time')
h_t = as.data.frame(h_t) %>%
  tidyr::pivot_longer(cols=starts_with('time'),
                      names_to = 'time')

S_t$time = rep(seq(from=1, to=t_time),n)
h_t$time = rep(seq(from=1, to=t_time),n)
S_t$person = unlist(lapply(1:n,
                           function(q)rep(paste0('Person ', q), dim_St)))

S_t$job_satisfaction = unlist(lapply(1:n,
                           function(q)rep(X$JobSatisfaction[q], dim_St)))
S_t$hazard = h_t$value
colnames(S_t) = c('Time', 'Survival probability', 'Person','Job satisfaction', 'hazard')
S_t_sub = S_t %>%
  group_by(`Job satisfaction`, Time)%>%
  summarize(`Survival probability`=mean(`Survival probability`),
            hazard=mean(hazard), .groups='drop')


S_t_sub %>%
  group_by( `Job satisfaction`) %>%
  na.omit(.) %>%
  summarize(across(c(`Survival probability`, hazard), list), .groups='drop') %>%

  reactable(
    .,
    theme = fivethirtyeight(centered = TRUE),
    compact = TRUE,
    columns = list(

      `Job satisfaction` = colDef(maxWidth = 122),

      `Survival probability` = colDef(
        cell = react_sparkline(
          data = .,
          height = 100,
          line_color = '#073d6d',
          line_width=3,
          min_value=0,
          max_value=1,
          decimals=2,
          labels = c('first', 'last'),
          label_size = '1.1em',
          point_size=2.5,
          highlight_points = highlight_points(first = '#7c1d1d',
                                              last = '#7c1d1d'),
          margin = reactablefmtr::margin(t=15,r=38,b=15,l=38),
          tooltip_type = 2, show_area=T
        )
      ),
      hazard = colDef(
        cell = react_sparkline(
          data = .,
          height = 100,
          line_color = '#073d6d',
          line_width=3,
          min_value=0,
        #  max_value=1,
          decimals = 2,
          labels = c('first', 'last'),
          label_size = '1.1em',
          point_size=2.5,
          highlight_points = highlight_points(first = '#7c1d1d',
                                              last = '#7c1d1d',
                                              max = '#7c1d1d'),
          margin = reactablefmtr::margin(t=15,r=42,b=15,l=42),
          tooltip_type = 2, show_area=T
        )
      )
    )
  )
# Per-employee sparklines for the first 8 employees (800 rows = 8 x 100 time points)
S_t[1:800,] %>%
  group_by( Person, `Job satisfaction`) %>%
  na.omit(.) %>%
  summarize(across(c(`Survival probability`, hazard), list), .groups='drop') %>%

  reactable(
    .,
    theme = fivethirtyeight(centered = TRUE),
    compact = TRUE,
    columns = list(
      Person = colDef(maxWidth = 125),
      `Job satisfaction` = colDef(maxWidth = 122),

      `Survival probability` = colDef(
        cell = react_sparkline(
          data = .,
          height = 100,
          line_color = '#073d6d',
          line_width=3,
          min_value=0,
          max_value=1,
          decimals=2,
          labels = c('first', 'last'),
          label_size = '1.1em',
          point_size=2.5,
          highlight_points = highlight_points(first = '#7c1d1d',
                                              last = '#7c1d1d'),
          margin = reactablefmtr::margin(t=15,r=38,b=15,l=38),
          tooltip_type = 2, show_area=T
        )
      ),
      hazard = colDef(
        cell = react_sparkline(
          data = .,
          height = 100,
          line_color = '#073d6d',
          line_width=3,
          min_value=0,
          #max_value=1,
          decimals = 2,
          labels = c('first', 'last'),
          label_size = '1.1em',
          point_size=2.5,
          highlight_points = highlight_points(first = '#7c1d1d',
                                              last = '#7c1d1d',
                                              max = '#7c1d1d'),
          margin = reactablefmtr::margin(t=15,r=42,b=15,l=42),
          tooltip_type = 2, show_area=T
        )
      )
    )
  )
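The script predicts from one fitted model with the indicator set to 1 and to 0; in the meta-learner literature a single model used this way is called an S-learner, while a T-learner fits a separate model to each treatment arm. The sketch below contrasts the two on simulated data, using base R `lm()` in place of BART. The data-generating process (true effect `tau = 1 + x`) and all variable names are assumptions for illustration.

```r
set.seed(7)
n_sim = 2000
x = runif(n_sim, -1, 1)
z_sim = rbinom(n_sim, 1, plogis(x))        # treatment more likely for larger x
tau = 1 + x                                # true heterogeneous effect
y_sim = 2 * x + z_sim * tau + rnorm(n_sim, sd = 0.5)
dat = data.frame(x = x, z = z_sim, y = y_sim)

# T-learner: one model per treatment arm, effect = difference of predictions
fit1 = lm(y ~ x, data = subset(dat, z == 1))
fit0 = lm(y ~ x, data = subset(dat, z == 0))
tau_T = predict(fit1, newdata = dat) - predict(fit0, newdata = dat)

# S-learner: one model with z as a covariate, predicted at z = 1 and z = 0
# (with a linear base learner and no interaction, the effect is constant)
fit_S = lm(y ~ x + z, data = dat)
tau_S = predict(fit_S, newdata = transform(dat, z = 1)) -
  predict(fit_S, newdata = transform(dat, z = 0))

c(true_ATE = mean(tau), T_learner = mean(tau_T), S_learner = mean(tau_S))
```

With a rigid base learner, the S-learner can only capture effect heterogeneity if the model includes it; a flexible learner such as BART can discover interactions between the indicator and the covariates, which is one reason the S-learner form is reasonable in the script.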