Discovery of complex feedback in a neural-embedded Friberg model using real-world data

Authors

Domas Linkevičius

Niklas Korsbo

The “Getting started with DeepNLME” tutorial demonstrated how to use DeepPumas to model complex, missing dynamical terms in a pharmacokinetic/pharmacodynamic (PK/PD) model. However, the data in that tutorial was generated from a relatively simple model, which could have been fitted effectively using traditional methods rather than the data-driven DeepPumas approach. In this tutorial we employ a more complex model and real-world data to illustrate that DeepPumas outperforms traditional PK/PD models in a more realistic scenario. We discuss the potential points of improvement of a classical PK/PD model and show how they can be addressed via DeepPumas. We will train both a classical model and a DeepNLME model, comparing their performance on a held-out data set. Finally, we will analyze the functional form of the feedback term in the classical model and contrast it with the term learned through the data-driven DeepPumas approach.

In this tutorial we utilize publicly available data from Jost et al. (2019), which includes white blood cell (WBC) measurements from patients with acute myeloid leukaemia treated with cytarabine. Significant portions of this tutorial are adapted from Martensen et al. (2024), who used DeepPumas and symbolic regression to discover symbolic expressions for the feedback term that mature white blood cells exert on proliferating cells.

We begin by loading the packages necessary for the tutorial.

using DeepPumas
using CairoMakie
using PharmaDatasets
set_theme!(deep_light())

1 Loading of the data

We provide a brief overview of the dataset used in this tutorial; for a detailed description see Jost et al. (2019). The data preparation steps closely follow those of Martensen et al. (2024).

The complete dataset comprises measurements of white blood cells from 23 patients over time, with each patient contributing a varying number of measurements (ranging from 22 to 94). Additionally, the dataset includes the amounts of cytarabine administered and the times at which it was administered. We load the data and split it into two subsets: trainpop, which will be used for model training, and testpop, which will be reserved for evaluating model performance after training.

df = dataset("journal.pone.0204540.csv")

trainpop = read_pumas(
    df[df.TRAINING, :],
    observations = [:WBC];
    id = :ID,
    time = :TIME,
    rate = :RATE,
    covariates = [:x0],
    evid = :EVID,
    amt = :AMT,
    cmt = :CMT,
    event_data = true,
);

testpop = read_pumas(
    df[.!df.TRAINING, :],
    observations = [:WBC];
    id = :ID,
    time = :TIME,
    rate = :RATE,
    covariates = [:x0],
    evid = :EVID,
    amt = :AMT,
    cmt = :CMT,
    event_data = true,
);

2 Classical PK/PD model definition

Our classical PK/PD model is from Jost et al. (2019), whose model is based on the work of Friberg et al. (2002). It consists of five state variables: two related to PK, the central compartment \(x_c\) and the peripheral compartment \(x_p\), and three related to PD, the proliferating cells \(x_{pr}\), one transition compartment \(x_{tr}\) (determined to be the optimal number in Martensen et al., 2024), and the mature white blood cells \(x_{ma}\). The model contains seven free parameters, five of which are fixed-effect typical values:

  • \(tvKtr\), the typical transition value
  • \(tvSlope\), the typical value of the slope of the drug influence on \(x_{pr}\)
  • \(tvB\), the typical baseline white blood cell count
  • \(tvB0\), the typical initial white blood cell count
  • \(tv\gamma\), the typical power parameter of the influence of \(x_{ma}\) on \(x_{pr}\)

The other two free parameters are \(\Omega\), the covariance matrix of the random effect prior, and additive observational noise standard deviation \(\sigma_{add}\).
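
Written out, the dynamics implemented in the @dynamics block below are

\[
\begin{aligned}
x_c' &= -\frac{CL + Q}{V_c}\, x_c + \frac{Q}{V_p}\, x_p \\
x_p' &= \frac{Q}{V_c}\, x_c - \frac{Q}{V_p}\, x_p \\
x_{pr}' &= -k_{tr}\, x_{pr} + E\, k_{tr}\, F_B\, x_{pr} \\
x_{tr}' &= k_{tr}\, x_{pr} - k_{tr}\, x_{tr} \\
x_{ma}' &= k_{tr}\, x_{tr} - k_{ma}\, x_{ma}
\end{aligned}
\]

with drug effect \(E = 1 - \mathrm{slope} \cdot \log(1 + c_V\, x_c)\) and feedback \(F_B = \left(\frac{B}{x_{ma}}\right)^{\gamma}\), where \(x_c\) and \(x_p\) correspond to the Central and Peripheral compartments in the code.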

classical_model = @model begin
    @param begin
        tvKtr ∈ RealDomain(; lower = 0.01)
        tvSlope ∈ RealDomain(; lower = 0.01)
        tvB ∈ RealDomain(; lower = 0.01)
        tvB0 ∈ RealDomain(; lower = 0.01)
        tvγ ∈ RealDomain(; lower = 0.01)
        Ω ∈ PDiagDomain(; init = fill(0.1, 5))
        σ_add ∈ RealDomain(; lower = 1e-3, init = 10)
    end

    @random begin
        η ~ MvNormal(Ω)
    end

    @pre begin
        # Model constants resulting from PK fit 
        """
        Clearance (L/d)
        """
        CL = 3.4993
        """
        Volume of Distribution (L)
        """
        Vc = 0.031415
        """
        Volume of distribution (L)
        """
        Vp = 0.0088778
        """
        Distribution clearance (L/d)
        """
        Q = 0.14848
        # Constants
        """
        Molecular mass of cytarabine (g/mol)
        """
        MM_cyt = 243.217
        """
        Effect based on molecular mass of cytarabine
        """
        c_V = inv(Vc * MM_cyt)
        """
        Death rate of mature cells (1/d)
        """
        k_ma = 2.3765

        # Parameters 
        """
        Transition rate (1/d)
        """
        k_tr = tvKtr * exp(η[1])
        """
        Effect for PD (L / μmol)
        """
        slope = tvSlope * exp(η[2])
        """
        Baseline WBC
        """
        B = tvB * exp(η[3])
        """
        Initial condition of WBC
        """
        B0 = tvB0 * exp(η[4])
        """
        Feedback exponent
        """
        γ = tvγ * exp(η[5])
    end

    @dosecontrol begin
        duration = (Central = 3 / 24,)
        bioav = (Central = 1.0011 / 8.0,)
    end

    @init begin
        Central = 0.0
        x_pr = B * k_ma / k_tr
        x_tr = B * k_ma / k_tr
        x_ma = B0
    end

    @vars begin
        E = 1 - slope * log(1 + c_V * Central)
        F_B = (B / (sqrt(x_ma^2) + 1e-1))^abs(γ) ### tricks for stability during optimization
    end

    @dynamics begin
        # PK Model 
        Central' = -(CL + Q) / Vc * Central + Q / Vp * Peripheral
        Peripheral' = Q / Vc * Central - Q / Vp * Peripheral
        # PD Model 
        x_pr' = -k_tr * x_pr + E * k_tr * F_B * x_pr
        x_tr' = k_tr * x_pr - k_tr * x_tr
        x_ma' = k_tr * x_tr - k_ma * x_ma
    end

    @derived begin
        WBC ~ @. Normal(x_ma, σ_add)
    end
end

The term of interest to us is the feedback term \(F_B = \left(\frac{B}{x_{ma}}\right)^\gamma\), which determines how \(x_{ma}\) feeds back onto \(x_{pr}\). We target this term because we think there might be room for improvement:

  1. It is likely a bit too simplistic. As we will show below on the data from Jost et al. (2019), it is possible to model the data more accurately using a more flexible function of \(x_{ma}\).
  2. The \(\frac{1}{x_{ma}}\) term is ill-behaved at low \(x_{ma}\), which, when raised to the power \(\gamma\), tends to result in numerical issues during parameter optimization, as the quick numerical check after this list illustrates.
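
To illustrate the second point, here is a quick numerical check; the values B = 5 and γ = 1.5 are purely illustrative and not fitted estimates:

B_demo, γ_demo = 5.0, 1.5                      # illustrative values, not fitted estimates
classical_F_B(x_ma) = (B_demo / x_ma)^γ_demo   # classical feedback term
# The feedback grows without bound as x_ma approaches zero,
# which can destabilize the ODE solve during optimization.
classical_F_B.([5.0, 1.0, 0.1, 0.01])          # ≈ [1.0, 11.2, 353.6, 11180.3]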

Having described the classical model, we next define the DeepNLME model.

3 DeepNLME model definition

We address the limitations of the feedback term \(F_B = (\frac{B}{x_{ma}})^\gamma\) by setting \(F_B = NN_\theta(\eta, x_{ma})\), where \(NN_\theta\) is a neural network with parameters \(\theta\). To keep the embedded neural network well-behaved, we cap it with a softplus output nonlinearity, which guarantees positive outputs. Otherwise, we keep the model structure exactly the same as in the classical model, since we are reasonably certain that it is sufficiently accurate.

deepnlme_model = @model begin
    @param begin
        tvKtr ∈ RealDomain(; lower = 0.01)
        tvSlope ∈ RealDomain(; lower = 0.01)
        tvB ∈ RealDomain(; lower = 0.01)
        tvB0 ∈ RealDomain(; lower = 0.01)
        Ω ∈ PDiagDomain(; init = fill(0.1, 6))
        tvF_B ∈ MLPDomain(
            3,
            7,
            7,
            (1, softplus),
            backend = :staticflux,
            bias = true,
            reg = L2(1.65e-1),
        )
        σ_add ∈ RealDomain(; lower = 0.0001, init = 10)
    end

    @random begin
        η ~ MvNormal(Ω)
    end

    @pre begin
        # Model constants resulting from PK fit 
        """
        Clearance (L/d)
        """
        CL = 3.4993
        """
        Volume of Distribution (L)
        """
        Vc = 0.031415
        """
        Volume of distribution (L)
        """
        Vp = 0.0088778
        """
        Distribution clearance (L/d)
        """
        Q = 0.14848
        # Constants
        """
        Molecular mass of cytarabine (g/mol)
        """
        MM_cyt = 243.217
        """
        Effect based on molecular mass of cytarabine
        """
        c_V = inv(Vc * MM_cyt)
        """
        Death rate of mature cells (1/d)
        """
        k_ma = 2.3765

        # Parameters 
        """
        Transition rate (1/d)
        """
        k_tr = tvKtr * exp(η[1])
        """
        Effect for PD (L / μmol)
        """
        slope = tvSlope * exp(η[2])
        """
        Baseline WBC
        """
        B = tvB * exp(η[3])
        """
        Initial condition of WBC
        """
        B0 = tvB0 * exp(η[4])
        """
        Neural feedback
        """
        F_B = only ∘ fix(tvF_B, η[5:6])
    end

    @dosecontrol begin
        duration = (Central = 3 / 24,)
        bioav = (Central = 1.0011 / 8.0,)
    end

    @init begin
        Central = 0.0
        x_pr = B * k_ma / k_tr
        x_tr = B * k_ma / k_tr
        x_ma = B0
    end

    @vars begin
        E = 1 - slope * log(1 + c_V * Central)
    end

    @dynamics begin
        # PK Model 
        Central' = -(CL + Q) / Vc * Central + Q / Vp * Peripheral
        Peripheral' = Q / Vc * Central - Q / Vp * Peripheral
        # PD Model 
        x_pr' = -k_tr * x_pr + E * k_tr * F_B(x_ma / 10.0) * x_pr
        x_tr' = k_tr * x_pr - k_tr * x_tr
        x_ma' = k_tr * x_tr - k_ma * x_ma
    end

    @derived begin
        WBC ~ @. Normal(x_ma, σ_add)
    end
end

With the classical and the neural embedded models defined, we move on to their training.

4 Model training

Given the relatively small trainpop size (18 subjects), the classical model converges in around 3 minutes on a regular machine. The DeepNLME model fit takes longer, around 20 minutes, since the neural network has many more parameters. We terminate both fits by setting an iteration limit beyond which the loss decreases only negligibly.

fpm_classical = fit(
    classical_model,
    trainpop,
    init_params(classical_model),
    MAP(FOCE()),
    optim_options = (; iterations = 60, show_trace = true),
)

fpm_deepnlme = fit(
    deepnlme_model,
    trainpop,
    init_params(deepnlme_model),
    MAP(FOCE()),
    optim_options = (; iterations = 40, show_trace = true, allow_f_increases = true),
)

We next plot and compare the performance of the trained models on the held-out data set.

5 Model comparison

Plotting the predictions of both models on the same axes shows that the DeepNLME model (dashed lines) matches the data slightly better overall than the classical model (solid lines).

test_preds_classical = map(testpop) do i
    predict(fpm_classical, i; obstimes = i.time[1]:0.1:i.time[end])
end
test_preds_deepnlme = map(testpop) do i
    predict(fpm_deepnlme, i; obstimes = i.time[1]:0.1:i.time[end])
end

plotgrid(
    test_preds_classical;
    pred = (; linestyle = :solid, label = "Classical pred"),
    ipred = (; linestyle = :solid, label = "Classical ipred"),
)
plotgrid!(
    test_preds_deepnlme;
    pred = (; linestyle = :dash, label = "DeepNLME pred"),
    ipred = (; linestyle = :dash, label = "DeepNLME ipred"),
)

Finally, we evaluate the generalization performance of both models numerically. For that we use two metrics: the marginal loglikelihood and the root-mean-square error (RMSE).

test_LL_classical = loglikelihood(fpm_classical.model, testpop, coef(fpm_classical), FOCE());
test_LL_deepnlme = loglikelihood(fpm_deepnlme.model, testpop, coef(fpm_deepnlme), FOCE());

The classical model's marginal loglikelihood on held-out data is -374.1, whereas the DeepNLME marginal loglikelihood is -352.6 (recall that higher likelihood is better). However, loglikelihoods can be difficult to interpret in real-world terms. Therefore, we calculate the difference between the RMSE values of the classical model and the DeepNLME model for each patient and show the differences as a boxplot.
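
For each subject \(i\) with \(n_i\) WBC observations, the per-subject RMSE of the individual predictions computed below is

\[
\mathrm{RMSE}_i = \sqrt{\frac{1}{n_i} \sum_{j=1}^{n_i} \left(\widehat{\mathrm{WBC}}_{ij} - \mathrm{WBC}_{ij}\right)^2},
\]

where \(\widehat{\mathrm{WBC}}_{ij}\) is the individual prediction (ipred) at the \(j\)-th observation time.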

function rmse(fpm, pop)
    preds = predict(fpm, pop)
    map(preds) do s
        # Per-subject RMSE of the individual predictions against the observed WBC
        sqrt(mean((s.ipred.WBC .- s.subject.observations.WBC) .^ 2))
    end
end

test_RMSE_classical = rmse(fpm_classical, testpop)
test_RMSE_deepnlme = rmse(fpm_deepnlme, testpop)

diffs = test_RMSE_classical .- test_RMSE_deepnlme


f = Figure()
ax = Axis(
    f[1, 1],
    title = "Trained model comparison",
    ylabel = "Classical RMSE - DeepNLME RMSE",
)
hidexdecorations!(ax)
boxplot!(ax, ones(length(diffs)), diffs)

display(f);
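
As a complementary summary (an addition to the original analysis), we can also compute the fraction of held-out subjects for which the DeepNLME RMSE is lower than the classical one; values above 0.5 favour the DeepNLME model:

# Fraction of held-out subjects where the DeepNLME model achieves the lower RMSE
mean(diffs .> 0)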

It is clear that the DeepNLME model does a better job of fitting the data, since the median and the 5% and 95% quantiles of the per-patient RMSE differences are above 0. We finish the tutorial by comparing the feedback term learned by the DeepNLME model with that of the classical model.

6 Comparing the learned feedback terms

Since we used the neural network to represent a very specific term in the model, we can now inspect and contrast it against the classical model feedback term.

params_classical = coef(fpm_classical)
params_deepnlme = coef(fpm_deepnlme)

classical_feedback(x_ma, B, γ) = (B / x_ma)^γ
deepnlme_feedback(x_ma) = only(params_deepnlme.tvF_B(zeros(2), x_ma / 10.0))

x_ma_vals = 0.01:0.01:20  # start slightly above zero to avoid division by zero in the classical feedback

f_comp = Figure()
ax_comp =
    Axis(f_comp[1, 1], limits = (0, 20, 0, 4), xlabel = L"x_{ma}", ylabel = L"F_B(x_{ma})")
lines!(
    ax_comp,
    collect(x_ma_vals),
    classical_feedback.(x_ma_vals, params_classical.tvB, params_classical.tvγ),
    label = "Classical",
    color = :black,
)
lines!(
    ax_comp,
    collect(x_ma_vals),
    deepnlme_feedback.(x_ma_vals),
    label = "DeepNLME",
    color = :red,
)
Legend(f_comp[1, 2], ax_comp)

display(f_comp);

As shown in the plot, the learned neural feedback function is more stable at low \(x_{ma}\) and decreases faster than the classical feedback as \(x_{ma}\) increases.

7 Conclusion

In this tutorial we applied DeepPumas to real-world data. We used a neural network to improve upon a feedback term in the Friberg model, avoiding extreme behaviour and potential numerical issues while gaining functional expressiveness. The resulting DeepNLME model outperformed the classical model on held-out data in both marginal loglikelihood and RMSE. Therefore, this tutorial provides a convincing use case showing that DeepPumas can address challenges faced by traditional PK/PD modelling in an efficient, data-driven manner under real-world circumstances.