Covariate Modelling in Nonlinear Mixed Effects Models using DeepPumas
1 Learning objectives
In this tutorial, you will learn to:
- Use DeepPumas to identify the covariate model in a nonlinear mixed effects (NLME) model.
- Augment an existing NLME model with a covariate model that uses a neural network.
- Perform hyperparameter optimization to find the best regularization strength for the neural network.
- Re-fit the variance parameters in the augmented model to reduce the variance of the random effects after incorporating the covariate model.
Note that in this tutorial we will also use a neural network as part of the dynamics. However, the covariate modelling approach we present here can be used with any NLME model, including those that do not use neural networks in their structure.
2 Setup
First, we need to load the required packages. We also set the backend for the MLPDomain to use :staticflux.
using DeepPumas
using StableRNGs
using CairoMakie
using Serialization
using Latexify
using PumasPlots
set_mlp_backend(:staticflux)
3 Synthetic data generation
We will generate synthetic data using an indirect response (IDR) model with complex covariate effects. Below, we define the data-generating model.
datamodel = @model begin
    @param begin
        tvKa ∈ RealDomain(; lower = 0, init = 0.5)
        tvCL ∈ RealDomain(; lower = 0)
        tvVc ∈ RealDomain(; lower = 0)
        tvSmax ∈ RealDomain(; lower = 0, init = 0.9)
        tvn ∈ RealDomain(; lower = 0, init = 1.5)
        tvSC50 ∈ RealDomain(; lower = 0, init = 0.2)
        tvKout ∈ RealDomain(; lower = 0, init = 1.2)
        Ω ∈ PDiagDomain(; init = fill(0.05, 5))
        σ ∈ RealDomain(; lower = 0, init = 5e-2)
    end
    @random begin
        η ~ MvNormal(Ω)
    end
    @covariates R_eq c1 c2 c3 c4 c5 c6
    @pre begin
        Smax = tvSmax * exp(η[1]) + 3 * c1 / (12.0 + c1)
        SC50 = tvSC50 * exp(η[2] + 0.2 * (c2 / 20)^0.75)
        Ka = tvKa * exp(η[3] + 0.3 * c3 * c4)
        Vc = tvVc * exp(η[4] + 0.3 * c3)
        Kout = tvKout * exp(η[5] + 0.3 * c5 / (c6 + c5))
        Kin = R_eq * Kout
        CL = tvCL
        n = tvn
    end
    @init begin
        Response = Kin / Kout
    end
    @vars begin
        cp := max(Central / Vc, 0.0)
        EFF := Smax * cp^n / (SC50^n + cp^n)
    end
    @dynamics begin
        Depot' = -Ka * Depot
        Central' = Ka * Depot - (CL / Vc) * Central
        Response' = Kin * (1 + EFF) - Kout * Response
    end
    @derived begin
        Outcome ~ @. Normal(Response, σ)
    end
end
We then define the population parameter values we will use to generate the synthetic data.
p_data = (;
    tvKa = 0.5,
    tvCL = 1.0,
    tvVc = 1.0,
    tvSmax = 1.2,
    tvn = 1.5,
    tvSC50 = 0.02,
    tvKout = 2.2,
    Ω = Diagonal(fill(0.05, 5)),
    σ = 0.1,
)
Next, we generate a synthetic population of 1020 subjects. Each subject receives a dose of 0.5 units at time 0 and an identical dose 8 hours later. We then define a distribution for each covariate and use the synthetic_data function to generate the population. Each subject is observed at times 0, 2, …, 24 hours.
dr = DosageRegimen(0.5, ii = 8, addl = 1)
pop = synthetic_data(
    datamodel,
    dr,
    p_data;
    covariates = (;
        R_eq = Gamma(50, 1 / (50)),
        c1 = Gamma(5, 2),
        c2 = Gamma(21, 1),
        c3 = Normal(),
        c4 = Normal(),
        c5 = Gamma(11, 1),
        c6 = Gamma(11, 1),
    ),
    nsubj = 1020,
    rng = StableRNG(123),
    obstimes = 0:2:24,
)
To visualize the distribution of the covariates, we use the covariates_dist
function.
covariates_dist(pop)
Finally, we create two training sets and one test set from the generated subjects. The first training set consists of the first 50 subjects. The second training set consists of the first 1000 subjects, including the first 50. The test set consists of the remaining subjects not included in either training set.
trainpop_small = pop[1:50]
trainpop_large = pop[1:1000]
testpop = pop[1001:end]
To visualize the test population, we use the predict and plotgrid functions to plot the data and predictions using the true data-generating model. This is the best possible prediction we can achieve.
pred_datamodel = predict(datamodel, testpop, p_data; obstimes = 0:0.1:24);
plotgrid(pred_datamodel)
4 Dynamics model identification using DeepNLME
Here, we define a model where the pharmacodynamic (PD) model is entirely determined by a neural network. At this point, no covariates have been included yet.
model = @model begin
    @param begin
        NN ∈ MLPDomain(5, 6, 5, (1, identity); reg = L2(0.1))
        tvKa ∈ RealDomain(; lower = 0)
        tvCL ∈ RealDomain(; lower = 0)
        tvVc ∈ RealDomain(; lower = 0)
        tvR₀ ∈ RealDomain(; lower = 0)
        ωR₀ ∈ RealDomain(; lower = 0)
        Ω ∈ PDiagDomain(2)
        Ω_nn ∈ PDiagDomain(3)
        σ ∈ RealDomain(; lower = 0)
    end
    @random begin
        η ~ MvNormal(Ω)
        η_nn ~ MvNormal(Ω_nn)
    end
    @pre begin
        Ka = tvKa * exp(η[1])
        Vc = tvVc * exp(η[2])
        CL = tvCL
        Response₀ = tvR₀ * exp(10 * ωR₀ * η_nn[1])
        iNN = fix(NN, η_nn)
    end
    @init begin
        Response = Response₀
    end
    @dynamics begin
        Depot' = -Ka * Depot
        Central' = Ka * Depot - (CL / Vc) * Central
        Response' = iNN(Central / Vc, Response)[1]
    end
    @derived begin
        Outcome ~ @. Normal(Response, σ)
    end
end
Next, we fit the model to the training data. The fit
function is used to estimate the parameters of the model using the training population trainpop_small
. The initial parameter estimates are provided by init_params(model)
, and we use the MAP(FOCE())
method for fitting. We run the fit for up to 300 iterations.
fpm = fit(
    model,
    trainpop_small,
    init_params(model),
    MAP(FOCE());
    optim_options = (; iterations = 300),
)
Next, we calculate the individual and population predictions using the fitted model on the test data using predict
and visualize the predictions using plotgrid
.
pred = predict(fpm, testpop; obstimes = 0:0.1:24);
plotgrid(pred)
The model has succeeded in discovering the dynamical model if the individual predictions match the observations of the test population well.
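If we want a quantitative counterpart to this visual check, a minimal sketch could compare the observations against the individual predictions at the observed times. This assumes the prediction objects expose an ipred field with Outcome predictions, analogous to the pred field we use in mae_preds further below, and that predict defaults to each subject's observation times.
# Sketch: mean absolute error between observations and individual predictions.
# Assumes `p.ipred.Outcome` is available, analogous to `p.pred.Outcome` used later.
pred_obs = predict(fpm, testpop)
iresid = mapreduce(vcat, testpop, pred_obs) do subj, p
    subj.observations.Outcome .- p.ipred.Outcome
end
mean(abs, iresid)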
5 Covariate model identification using DeepNLME
So far, all the between-subject variability (BSV) in the data was captured by the random effects. Now, we will “augment” that covariate-free NLME model with a machine learning (ML) covariate model. Capturing some of the BSV with a covariate model reduces the amount of BSV that needs to be captured by the random effects, which in turn will improve the model’s predictive accuracy at baseline (aka population prediction).
In particular, we will use a neural network to predict the random effects’ posterior distribution from the covariates. In DeepPumas, you can call the preprocess
function on any fitted Pumas model to generate a target for the ML fitting. The target consists of:
- The standardized covariates, and
- The empirical Bayes estimates (EBEs) of the random effects, scaled to weight well-identified random effects more heavily than poorly identified ones.
target = preprocess(fpm)
FitTarget(ℝ⁷→ℝ⁵, 50 datapoints, (:R_eq, :c1, :c2, :c3, :c4, :c5, :c6) → (:η, :η_nn))
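The target maps the 7 covariates to the 5 scaled EBEs (η and η_nn). We can query these dimensions with the numinputs and numoutputs functions, the same ones we use below to size the neural network:
numinputs(target), numoutputs(target)  # (7, 5): covariates in, scaled EBEs out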
We can then create our neural network model. We first switch to the :simplechains
backend for MLPDomain
because it is faster for covariate modelling. The MLPDomain
function is then used to define a multi-layer perceptron (MLP) with as many inputs as the number of covariates, 7 hidden units in each of the three hidden layers, and as many outputs as the number of random effects we are fitting towards. The reg
argument specifies a regularization term to prevent over-fitting.
set_mlp_backend(:simplechains)
nn = MLPDomain(numinputs(target), 7, 7, 7, (numoutputs(target), identity); reg = L2(1.0))
We can then fit the neural network to the target using the fit
function.
fnn = fit(nn, target)
The fitted neural network can be augmented to the original model using the augment
function. This will create a new model that includes the neural network as a covariate model.
augmented_fpm = augment(fpm, fnn)
After incorporating the covariate model, the random effects in the new model will now represent the residual BSV in the individual parameters that is not accounted for by the covariates.
Using the new model, we can calculate the individual and population predictions on the test data using predict
and visualize the predictions using plotgrid
. For comparison, we overlay 3 sets of population predictions:
- The predictions from the augmented model with covariates,
- The predictions from the true data-generating model, and
- The predictions from the original model without covariates.
pred_augment =
    predict(augmented_fpm.model, testpop, coef(augmented_fpm); obstimes = 0:0.1:24);
plotgrid(
    pred_datamodel;
    ipred = false,
    pred = (; color = (:black, 0.4), label = "Best possible pred"),
)
plotgrid!(pred; ipred = false, pred = (; color = (:red, 0.2), label = "No covariate pred"))
plotgrid!(pred_augment; ipred = false, pred = (; linestyle = :dash))
To get a quantitative measure of how close our augmented model’s population predictions are to the population predictions of the data-generating model, we define the following function to calculate the mean absolute error (MAE) between the 2 sets of population predictions.
function mae_preds(pred1, pred2)
    resid = mapreduce(hcat, pred1, pred2) do p1, p2
        p1.pred.Outcome .- p2.pred.Outcome
    end
    return mean(abs, resid)
end
We use the function to calculate the MAE in the population predictions across the entire test population.
mae_preds(pred_datamodel, pred_augment)
0.2541189148897362
We compare this to the MAE between the population predictions of the data-generating model and the original model without covariates.
mae_preds(pred_datamodel, pred)
0.33088606309647844
Note how the model has improved its population predictions compared to the original model without covariates.
6 Hyperparameter optimization
In our covariate model above, we used an L2 regularization with a regularization strength of 1.0. But we don’t know if this is the best regularization strength. We can use hyperparameter optimization (HPO) to automatically search for the best regularization strength. In DeepPumas, instead of calling fit
on the neural network and target, we can call hyperopt
to perform HPO of the regularization strength \(\lambda\). hyperopt
takes as input an optional algorithm for the hyperparameter optimization, e.g. GridSearch
. Below, we use hyperopt
to find the best value of \(\lambda\) for L2 regularization by searching through a grid of 40 evenly spaced values in log scale, ranging from \(10^{-7}\) to \(10^{5}\). We use 10-fold cross-validation to evaluate the performance of the model with different hyperparameter values. The best hyperparameter value is then used to fit the model again using the whole training dataset.
ho = hyperopt(
    nn,
    target,
    GridSearch(
        (; lambda = (1e-7, 1e5));
        criteria = KFold(K = 10),
        scale = :log10,
        resolution = 40,
    ),
)
ho.best_hyperparameters
(lambda = 4.923882631706739,)
Two interesting cases can arise from the hyperparameter optimization:
- If the best \(\lambda\) value is very large, the regularization will shrink all the NN parameters to (nearly) 0, indicating that the covariates are not useful for predicting the random effects. In this case, we can simply use the original model without covariates.
- If the best \(\lambda\) is the maximum (minimum) value in the grid search and that value is not too large (small), it indicates that the range of values we searched over was not wide enough, and we should widen the range of search values for \(\lambda\); see the sketch after this list.
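For the second case, a quick programmatic check might look like the following sketch, which reuses the ho result and the grid endpoints from the hyperopt call above:
# Sketch: warn if the selected λ sits on the boundary of the search grid,
# which suggests the search range should be widened.
best_λ = ho.best_hyperparameters.lambda
if best_λ ≈ 1e-7 || best_λ ≈ 1e5
    @warn "Best λ is at the edge of the search grid; widen the λ range." best_λ
end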
Using the best hyperparameter value, we can augment the original model with the new neural network and plot the predictions again.
augmented_fpm = augment(fpm, ho)
pred_augment_ho =
    predict(augmented_fpm.model, testpop, coef(augmented_fpm); obstimes = 0:0.1:24);
plotgrid(
    pred_datamodel;
    ipred = false,
    pred = (; color = (:black, 0.4), label = "Best possible pred"),
)
plotgrid!(pred; ipred = false, pred = (; color = (:red, 0.2), label = "No covariate pred"))
plotgrid!(pred_augment_ho; ipred = false, pred = (; linestyle = :dash))
We then calculate the MAE between the population predictions of the data-generating model and the augmented model with hyperparameter optimization.
mae_preds(pred_datamodel, pred_augment_ho)
0.20026065258097048
Note how the model has improved its population predictions compared to the original model without covariates and compared to the model with a fixed regularization strength.
7 Data size
Training covariate models well requires more data than fitting the neural networks embedded in dynamical systems. With UDEs, every observation is a data point; with covariate models, every subject is a data point. Here, each subject contributes 13 observations (times 0, 2, …, 24 hours) to the dynamics fit but only a single data point to the covariate fit. We've managed to improve our model using only 50 subjects, but let's try using data from 1000 subjects instead.
To create the fitting target, we use another call signature (aka method) of the preprocess
function that takes the model, a population, population parameters, and an algorithm. This will create a target that includes the covariates and the EBEs of the random effects scaled appropriately.
target_large = preprocess(model, trainpop_large, coef(fpm), FOCE())
FitTarget(ℝ⁷→ℝ⁵, 1000 datapoints, (:R_eq, :c1, :c2, :c3, :c4, :c5, :c6) → (:η, :η_nn))
We then do HPO again using the hyperopt
function, but this time we use the larger target. Finally, we augment the original model with the new neural network and plot the predictions again.
fnn_large = hyperopt(
    nn,
    target_large,
    GridSearch(
        (; lambda = (1e-7, 1e5));
        criteria = KFold(K = 10),
        scale = :log10,
        resolution = 40,
    ),
)
fnn_large.best_hyperparameters
augmented_fpm_large = augment(fpm, fnn_large)
We then make the same predictions as before.
pred_augment_large = predict(
    augmented_fpm_large.model,
    testpop,
    coef(augmented_fpm_large);
    obstimes = 0:0.1:24,
);
plotgrid(
    pred_datamodel;
    ipred = false,
    pred = (; color = (:black, 0.4), label = "Best possible pred"),
)
plotgrid!(pred; ipred = false, pred = (; color = (:red, 0.2), label = "No covariate pred"))
plotgrid!(pred_augment_large; ipred = false, pred = (; linestyle = :dash))
Then, we calculate the MAE between the population predictions of the data-generating model and the augmented model with hyperparameter optimization using the larger target.
mae_preds(pred_datamodel, pred_augment_large)
0.10991770173083092
Notice how the model has improved its population predictions compared to the model trained on 50 subjects.
8 Re-fitting the augmented model
After augmenting the model, we could fit everything in concert, starting the fit from our sequentially attained parameter values. However, re-fitting all the parameters, including the covariate model's neural network, is computationally expensive and, in our experience, often unnecessary.
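For reference, such a joint re-fit would look like the following sketch (not executed in this tutorial):
# Sketch only: re-fit all parameters, including the covariate neural network,
# warm-started from the sequentially obtained estimates. This is the expensive
# step we avoid below.
fpm_full = fit(
    augmented_fpm_large.model,
    trainpop_large,
    coef(augmented_fpm_large),
    MAP(FOCE()),
)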
Instead, even if we don’t re-fit every parameter, we should still fit the variance parameters of the random effects, such that we don’t overestimate the residual BSV now that we captured some of the BSV using the covariate model. We can achieve this using the constantcoef
argument of the fit
function. This allows us to specify which parameters we want to keep fixed during the fit. In this case, we will keep all parameters fixed except for the variance parameters of the random effects. We run the fit for up to 200 iterations.
all_param_names = keys(coef(augmented_fpm_large))
refit_param_names = (:Ω, :Ω_nn)
constantcoef = Tuple(setdiff(all_param_names, refit_param_names))

fpm_refit_Ω = fit(
    augmented_fpm_large.model,
    trainpop_large,
    coef(augmented_fpm_large),
    MAP(FOCE());
    constantcoef,
    optim_options = (; iterations = 200),
)
We can then compare the variance parameters of the random effects before and after the re-fit. The coef
function returns the coefficients of the fitted model, and we can access the Ω
and Ω_nn
parameters to compare their diagonals.
diag(coef(fpm_refit_Ω).Ω) ./ diag(coef(augmented_fpm).Ω)
2-element Vector{Float64}:
1.2299825516606964
0.6383034477191067
diag(coef(fpm_refit_Ω).Ω_nn) ./ diag(coef(augmented_fpm).Ω_nn)
3-element Vector{Float64}:
0.22399834817371
0.22801980795937912
0.4252922254309569
Note the reduced variance for most of the random effects after the re-fit. This indicates that the covariate model has captured some of the BSV in the original model.
To show the importance of the re-fit, we can simulate new subjects from the augmented model and compare the simulations before and after the re-fit using a visual predictive check (VPC).
First, we show the VPC of the augmented model before the re-fit.
augmented_vpc_res =
    vpc(augmented_fpm_large, samples = 50, observations = [:Outcome], bandwidth = 6)
vpc_plot(
    augmented_vpc_res;
    simquantile_medians = true,
    axis = (
        xlabel = "Time (hours)",
        ylabel = "Outcome",
        title = "VPC of Augmented Model Before Re-fit",
    ),
)
Next, we show the VPC of the augmented model after the re-fit.
refit_vpc_res = vpc(fpm_refit_Ω, samples = 50, observations = [:Outcome], bandwidth = 6)
vpc_plot(
    refit_vpc_res;
    simquantile_medians = true,
    axis = (
        xlabel = "Time (hours)",
        ylabel = "Outcome",
        title = "VPC of Augmented Model After Omega Re-fit",
    ),
)
Note the overall reduced variance in the simulations after the re-fit. We only use 50 simulated samples per subject in the above VPCs because the population already contains 1000 subjects.
9 Conclusion
In this tutorial, we have learned how to use DeepPumas to identify the covariate model in a nonlinear mixed effects model. We have seen how to generate synthetic data, fit a neural-embedded NLME model, and augment the model with a covariate model using a neural network. We also learned how to perform hyperparameter optimization to find the best regularization strength for the neural network. Finally, we saw how to re-fit the variance parameters in the augmented model to reduce the variance of the random effects after incorporating the covariate model.