Mixed-effects neural networks
1 Learning objectives
By the end of this tutorial, you will be able to:
- Understand the value of a mixed-effects neural network (MeNet) model over a neural network function of time.
- Define and fit a MeNet model using DeepPumas.
- Explore the impact of data quality and quantity on model performance.
- Analyze the effect of the number of random effects on the model’s ability to capture between-subject variability.
2 Setup
First, we need to load the packages we will use in this tutorial. We set the MLP backend to :staticflux and set the random seed for reproducibility.
using Random
using Distributions
using DeepPumas
using CairoMakie
set_mlp_backend(:staticflux)
Random.seed!(1234)
3 Synthetic data and true model
We will generate synthetic data to illustrate the use of mixed-effects neural networks. The data will be generated from a model that has a time-dependent outcome, Y, which is influenced by two parameters, c1 and c2, that vary between subjects.
datamodel_me = @model begin
    @param σ ∈ RealDomain(; lower = 0.0, init = 0.05)
    @random begin
        c1 ~ Uniform(0.5, 1.5)
        c2 ~ Uniform(-1, 0)
    end
    @pre X = c1 * t / (t + exp10(c2))
    @derived Y ~ @. Normal(X, σ)
end
We will simulate data from this model for 112 subjects, with each subject having observations at 21 time points between 0 and 1. The first 100 subjects will be used for training, and the remaining 12 subjects will be used for testing. We visualize the first 12 subjects in the training set using plotgrid.
sims = simobs(datamodel_me, [Subject(; id) for id = 1:112], (; σ = 0.05); obstimes = 0:0.05:1)
trainpop_me = Subject.(sims[1:100])
testpop_me = Subject.(sims[101:end])
plotgrid(trainpop_me[1:12])
4 Neural network function of time
Let’s begin by fitting a simple neural network function of time to the training data. This neural network describes Y as a function of time but does not account for the between-subject variability in the parameters.
The neural network is defined using MLPDomain, which specifies the architecture of the neural network. We will use a 4-layer MLP with 2 hidden layers, each having 4 nodes. The activation function of the output layer is set to identity, meaning it will output values between -Inf and Inf, which is suitable for regression tasks. We also apply L2 regularization to the neural network parameters to lower the risk of overfitting. The default activation function for the input and hidden layers is tanh.
model_t = @model begin
    @param begin
        NN ∈ MLPDomain(1, 4, 4, (1, identity); reg = L2(1.0))
        σ ∈ RealDomain(; lower = 0.0)
    end
    @pre X = NN(t)[1]
    @derived Y ~ Normal.(X, σ)
end
Next, we fit this model to the training data. We use the fit function, initializing the parameters with init_params(model_t). The fitting is done using the MAP estimation method with a NaivePooled approach. MAP ensures that the regularization of the neural network parameters is added to the objective function value during fitting. NaivePooled is used because the model does not have any random effects.
fpm_t = fit(model_t, trainpop_me, init_params(model_t), MAP(NaivePooled()))
We then predict the outcome Y for the training data using the fitted model. The predict function generates predictions for a range of time points from 0 to 1, with a step size of 0.01.
preds = predict(fpm_t; obstimes = 0:0.01:1)
Finally, we visualize the training data and the prediction for the first subject. Since there are no covariates or random effects in this model, the prediction will be the same for all subjects, so we only show one of them.
The scatter function is used to plot the observed data points, and the lines! function overlays the predicted values for the first subject in the training data. The x-axis represents time, and the y-axis represents the outcome Y.
pred = preds[1]
df = DataFrame(trainpop_me);
fig, ax, plt = scatter(df.time, df.Y; label = "Data", axis = (; xlabel = "Time", ylabel = "Y"));
lines!(ax, collect(pred.time), pred.pred.Y; color = Cycled(2), label = "NN prediction")
axislegend(ax; position = :lt)
fig
Notice how the neural network captures the general trend of the data, but it does not account for the individual variability between subjects. The prediction is made for the average observations across all subjects, which does not reflect the true individual trajectories.
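As a quick, optional check (not part of the original workflow), we can confirm this numerically: because the fitted model has no subject-specific inputs, the population prediction is literally the same curve for every subject.
# Optional sanity check (illustrative): with no covariates and no random effects,
# the population prediction is identical across subjects.
preds[1].pred.Y ≈ preds[2].pred.Y  # expected to return true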
5 Mixed-effects neural network
To address the above issue, we will now extend the model to include random effects, allowing the neural network to learn individual-specific functions of time. In the following model, 2 random effects are used as input to the neural network.
model_me = @model begin
    @param begin
        NN ∈ MLPDomain(3, 6, 6, (1, identity); reg = L2(1.0))
        σ ∈ RealDomain(; lower = 0)
    end
    @random η ~ MvNormal(0.01 * I(2))
    @pre X = NN(t, η)[1]
    @derived Y ~ @. Normal(X, σ)
end
We now fit this mixed-effects neural network (MeNet) model to the training data. The fit function is used with the MAP(FOCE()) method, which maximizes the product of the prior and the marginal likelihood.
fpm_me = fit(
    model_me,
    trainpop_me,
    init_params(model_me),
    MAP(FOCE());
    optim_options = (; iterations = 200),
)
After the fit, we predict the outcome Y for the training data using the fitted model. The predict function generates predictions for a range of time points from 0 to 1, with a step size of 0.01. We plot the first 12 subjects’ predictions to visualize the model’s performance on the training data.
pred_train = predict(fpm_me; obstimes = 0:0.01:1)[1:12]
plotgrid(pred_train)
We predict and plot the predictions for the test data as well.
pred_test = predict(fpm_me, testpop_me[1:12]; obstimes = 0:0.01:1)
plotgrid(pred_test; ylabel = "Y (Test data)")
Notice how the mixed-effects neural network captures the individual variability in the data, accurately fitting the individual time profile of each subject in the training and test data.
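To get a feel for what the model has learned, it can also be instructive to inspect the estimated random effects. The following sketch is not part of the original tutorial; it scatters the per-subject empirical Bayes estimates of the two components of η for the training population, and the names ebes, η1, and η2 are just illustrative.
# A minimal sketch: plot the per-subject empirical Bayes estimates of η.
# Spread along both axes suggests that both random effects are being used
# to capture between-subject variability.
ebes = empirical_bayes(fpm_me)
η1 = map(x -> x.η[1], ebes)
η2 = map(x -> x.η[2], ebes)
scatter(η1, η2; axis = (; xlabel = "η[1]", ylabel = "η[2]"))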
6 Data quality and quantity
The quality of the fits here depends on a few different things. Among these are:
- The number of training subjects
- The number of observations per subject
- The noisiness of the data
- The regularization of your embedded neural network
You may not need many patients to train on if your data is good. But if the data quality is a bit off, then data quantity might compensate.
Let’s lower the training data quality by increasing the noise, and decrease the training data quantity by reducing the training set to 6 subjects with 4 observations each.
sims_new = simobs(
    datamodel_me,
    [Subject(; id) for id = 1:6], # Change the number of patients
    (; σ = 0.1); # Tune the additive noise
    obstimes = 0:0.3:1, # Modify the observation times
)
traindata_new = Subject.(sims_new)
plotgrid(traindata_new)
We fit the same model to this lower-quality training data.
fpm_me_2 = fit(
    model_me,
    traindata_new,
    sample_params(model_me),
    MAP(FOCE());
    optim_options = (; iterations = 300, time_limit = 3 * 60),
)
We then plot the predictions for the training data.
pred_train = predict(fpm_me_2; obstimes = 0:0.01:1)[1:min(12, end)]
plotgrid(pred_train[1:min(12, end)]; ylabel = "Y (training data)")
And we visualize the predictions for the old densely sampled test data.
pred_test = predict(model_me, testpop_me, coef(fpm_me_2); obstimes = 0:0.005:1)
plotgrid(pred_test; ylabel = "Y (Test data)")
Notice how the model is not able to perfectly fit the training data anymore, and the predictions for the test data are also not as accurate as before. The model struggles to capture the individual variability in the data due to the reduced number of subjects and observations, as well as the increased noise.
You can download the Quarto tutorial, edit the code above and test different combinations of data quality and quantity to see how they affect the model’s performance.
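For example, a small scan along the following lines could automate such a comparison. This is only an illustrative sketch: the specific settings, the loop variable names, and the choice of sample_params and 200 iterations are assumptions rather than part of the tutorial.
# Illustrative sketch: refit the MeNet on simulated datasets of varying
# quality/quantity and inspect the test-data predictions for each fit.
for (nsubj, σ, Δt) in [(6, 0.1, 0.3), (24, 0.1, 0.3), (100, 0.05, 0.05)]
    sims_i = simobs(
        datamodel_me,
        [Subject(; id) for id = 1:nsubj],
        (; σ); # additive noise level for this scenario
        obstimes = 0:Δt:1,
    )
    fpm_i = fit(
        model_me,
        Subject.(sims_i),
        sample_params(model_me),
        MAP(FOCE());
        optim_options = (; iterations = 200),
    )
    display(plotgrid(predict(fpm_i, testpop_me[1:12]; obstimes = 0:0.01:1)))
end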
7 How many random effects?
The number of random effects in a mixed-effects neural network model is an important factor that influences the model’s ability to capture between-subject variability. The number of random effects should ideally match the dimensionality of the outcome heterogeneity in the data. This is an abstract concept if we don’t know the true data-generating model, but we can illustrate it with our synthetic data.
In our synthetic data, the between-subject variability arises from two parameters, c1 and c2. These parameters are linearly independent and influence the patient trajectories in distinct ways. A change in c1 cannot be offset by a change in c2, meaning there are two separate dimensions of between-subject variability in the data.
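To make this concrete, here is a small illustration (not part of the original workflow) of the true conditional mean X(t) = c1 * t / (t + 10^c2) for a few (c1, c2) pairs inside their prior ranges; c1 sets the asymptotic magnitude of the curve while exp10(c2) sets the half-saturation time, so a change in one cannot be mimicked by a change in the other. The names fig_c, ax_c, and ts are just illustrative.
# Illustration of the data-generating mean X(t) = c1 * t / (t + 10^c2):
# c1 scales the curve, exp10(c2) shifts where it rises.
fig_c = Figure()
ax_c = Axis(fig_c[1, 1]; xlabel = "t", ylabel = "X")
ts = 0:0.01:1
for (c1, c2) in [(0.6, -0.8), (0.6, -0.2), (1.4, -0.8), (1.4, -0.2)]
    lines!(ax_c, ts, c1 .* ts ./ (ts .+ exp10(c2)); label = "c1 = $c1, c2 = $c2")
end
axislegend(ax_c; position = :lt)
fig_c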
The model_me includes a two-dimensional vector of independent random effects, allowing it to capture variability along two dimensions. During the fitting process, the neural network determines how to best utilize these random effects to model the data. While the way the model uses the random effects may not exactly match the data-generating process, the model’s dimensionality aligns with the data’s variability. With sufficient high-quality data, the model should be able to perfectly predict individual trajectories.
But what happens if the model has fewer dimensions of between-subject variability than the data? We can explore this by reducing the number of random effects provided to the neural network in model_me.
The model is now too simple to perfectly fit the data. We will train this model on high-quality data to observe how the fit fails and in what specific way the model struggles to capture the variability.
model_me2 = @model begin
    @param begin
        NN ∈ MLPDomain(2, 6, 6, (1, identity); reg = L2(1.0)) # We now only have 2 inputs as opposed to 3 in model_me
        σ ∈ RealDomain(; lower = 0.0)
    end
    @random η ~ Normal(0, 1)
    @pre X = NN(t, η)[1]
    @derived Y ~ Normal.(X, σ)
end
sims_great =
    simobs(datamodel_me, [Subject(; id) for id = 1:100], (; σ = 0.01); obstimes = 0:0.05:1)
great_data = Subject.(sims_great)
We plot the first 12 subjects in the training data to visualize the high-quality data.
plotgrid(great_data[1:12])
We then fit the simple model model_me2 to the data and visualize the first 24 subjects’ predictions.
fpm_me2 = fit(
    model_me2,
    great_data,
    sample_params(model_me2),
    MAP(FOCE());
    optim_options = (; time_limit = 3 * 60),
)
pred_train = predict(fpm_me2; obstimes = 0:0.01:1)[1:24]
plotgrid(pred_train; ylabel = "Y (training data)")
The fits appear to be quite good! But let’s take a closer look at the predictions of 6 subjects:
pred_train = predict(fpm_me2; obstimes = 0:0.01:1)[[2, 4, 6, 7, 10, 11]]
plotgrid(pred_train; ylabel = "Y (training data)")
Note that the model should not be able to make a perfect fit here, as it has fewer random effects than the data-generating process. Looking closely at the 6 selected subjects, the fit is clearly biased, often overestimating the beginning and underestimating the end, or the opposite.
The model fit above was able to find a single dimension that explains most, but not all, of the variability in the data. But how did it do this? Note that the 2 random effects in the original data-generating model are only independent in their prior distributions, not their posteriors! The neural network was therefore able to learn a representation of the data using a lower dimension of random effects than the true dimension, much like how dimension reduction works in principal component analysis.
We can visualize exactly how the neural network utilizes this single random effect to define the individual function of time describing the response Y.
fig = Figure()
ax = Axis(fig[1, 1]; xlabel = "t", ylabel = "NN(t, η)")
ηs = map(x -> x.η[1], empirical_bayes(fpm_me2))
trange = 0:0.01:1
nn = coef(fpm_me2).NN
colorrange = (minimum(ηs), maximum(ηs))
for η in ηs
    lines!(ax, trange, first.(nn.(trange, η)); color = η, colormap = :Spectral, colorrange)
end
Colorbar(fig[1, 2]; colorrange, colormap = :Spectral, label = "η")
fig