DeepPumas Epidemiology: force-of-infection across heterogeneous countries
DeepPumas allows us to integrate data-driven function identification that accounts for variability within different stratifications of our dataset. In pharmacology, this approach is often used to identify disease-specific functions that are partially individualized across patients. However, a population-subject split is only one of many possible and relevant ways to stratify data.
In this tutorial, we will use DeepPumas to analyze epidemiological data, focusing on a disease that exhibits both similarities and differences in its spread across countries. The nonlinear mixed effects (NLME) framework enables us to effectively manage data from heterogeneous sources, even when it is unbalanced. Each country may differ not only in the mechanisms of disease transmission but also in the availability and quality of data. A key question is whether we can leverage abundant data from certain countries to enhance predictions in countries where data is sparse.
This is a novel application area not only for DeepNLME but also for NLME frameworks in general. While our analysis will focus on a specific question, DeepNLME enables us to separate common trends from individual variations across datasets, opening up a broad range of possible questions to explore.
Here, we will study the force-of-infection of a virus that individuals typically acquire over the course of their lives and do not eliminate. In this case, our data is cross-sectional rather than longitudinal: for each age, we observe the percentage of people testing positive for the virus, and this percentage increases with age.
We will start by loading a few necessary packages and setting a theme to enhance the visual clarity of our plots.
using DeepPumas
using CairoMakie
using AlgebraOfGraphics
using Random
using DataFramesMeta
set_theme!(deep_light())
1 Synthetic data generation
We next generate a synthetic data set in which the fraction of positive tests is stratified by country and age, and in which the country strongly affects the relationship between age and the positive test rate.
We start by defining a force-of-infection function \(λ\) that depends on age.
λ(age) =
pdf.(Normal(16, 7), age) .+ 1.2 .* pdf.(Normal(37, 10), age) .+
0.4 .* pdf.(Normal(65, 14), age) .+ 0.02
age = 0:0.1:100
fig = Figure()
ax = Axis(
    fig[1, 1],
    limits = (nothing, nothing, 0, nothing),
    xlabel = "Age (years)",
    ylabel = "Force of infection λ(age)",
)
lines!(ax, age, λ)
display(fig);
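Integrating λ shows what it implies for the data. The following stand-alone sketch uses only Base Julia. It re-implements the same λ with a hand-written normal density, accumulates the cumulative hazard with the trapezoidal rule, and converts it into a susceptible fraction for an assumed country scale c = 0.3. The names normal_pdf, foi, cumulative_hazard and the other variables are illustrative and not part of DeepPumas.

```julia
# Stand-alone sketch using only Base Julia (no DeepPumas / Distributions).
# Hand-written Normal density, equivalent to pdf(Normal(μ, σ), x).
normal_pdf(x, μ, σ) = exp(-0.5 * ((x - μ) / σ)^2) / (σ * sqrt(2π))

# Same three-bump force of infection as λ above, plus the 0.02 baseline.
foi(a) = normal_pdf(a, 16, 7) + 1.2 * normal_pdf(a, 37, 10) +
         0.4 * normal_pdf(a, 65, 14) + 0.02

# Cumulative hazard Λ(a) = ∫₀ᵃ f(x) dx, accumulated with the trapezoidal rule.
function cumulative_hazard(f, grid)
    Λ = zeros(length(grid))
    for k in 2:length(grid)
        h = grid[k] - grid[k-1]
        Λ[k] = Λ[k-1] + h * (f(grid[k-1]) + f(grid[k])) / 2
    end
    return Λ
end

ages_grid = collect(0.0:0.1:100.0)
Λ_grid = cumulative_hazard(foi, ages_grid)

# Assumed country scale c (0.3 matches tvc in the synthetic data below).
# dS/da = -c λ(a) S with S(0) = 1 gives S(a) = exp(-c Λ(a)).
c_demo = 0.3
S_curve = exp.(-c_demo .* Λ_grid)
infected_fraction = 1 .- S_curve
```

Because foi is strictly positive (it never drops below the 0.02 baseline), S_curve falls monotonically from 1. The infected fraction can therefore only rise with age under this kind of model.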
Next, we define a data-generating model that uses the force-of-infection, applying a country-specific scaling factor to simulate synthetic data for various countries, each with a different number of measurements. For reference, the complete code is provided below, though the model details are not essential for following this tutorial. We also plot the generated data, marking each country as training or test data.
datamodel = @model begin
    @param begin
        σ ∈ RealDomain(; lower = 0, init = 10)
        tvc ∈ RealDomain(; lower = 0)
        ω ∈ RealDomain(; lower = 0)
    end
    @random η ~ Normal(0, ω)
    @pre begin
        # Country-specific scaling of the shared force of infection
        c = tvc * exp(η)
        _λ = t -> c * λ(t)
    end
    @init S = 1
    @dynamics begin
        S' = -_λ(t) * S
    end
    @derived begin
        """
        Positive test fraction
        """
        PosFrac ~ @. Beta(σ * (1 - 0.99S), σ * (0.99S))
    end
    @observed begin
        # Expose the country-specific force-of-infection function for later comparison
        iλ = _λ[1]
    end
end
countries = [
    "Austria",
    "Belgium",
    "Bulgaria",
    "Croatia",
    "Cyprus",
    "Czech Republic",
    "Denmark",
    "Estonia",
    "Finland",
    "France",
    "Germany",
    "Greece",
    "Hungary",
    "Ireland",
    "Italy",
    "Latvia",
    "Lithuania",
    "Sweden",
]
sparse_countries = ["Brazil", "Vietnam"]
data_params = (; tvc = 0.3, σ = 1000.0, ω = 0.6)
obstimes = vcat(0:5, 10:5:100)
Random.seed!(123)

# Data-rich countries: between 1 and 20 observations on a fixed age grid
sim_data = map(eachindex(countries)) do i
    simobs(
        datamodel,
        Subject(; id = countries[i]),
        data_params;
        obstimes = sort(sample(obstimes, rand(1:20); replace = false)),
    )
end

# Sparse countries: only 1 or 2 observations at arbitrary ages
sim_data_sparse = map(eachindex(sparse_countries)) do i
    simobs(
        datamodel,
        Subject(; id = sparse_countries[i]),
        data_params;
        obstimes = sort(sample(1:100, rand(1:2); replace = false)),
    )
end

all_countries = vcat(countries, sparse_countries)
all_sims = vcat(sim_data, sim_data_sparse)
all_data = Subject.(all_sims)

# Hold out the last four countries (Lithuania, Sweden, Brazil, Vietnam) for testing
train_data = all_data[1:end-4]
val_data = all_data[end-3:end]
plotgrid(
    all_data;
    xlabel = "AGE",
    ylabel = "Fraction positive tests",
    legend = false,
    data = (; markersize = 10),
    axis = (; yticks = 0:0.25:1),
    title = (s, i) -> "$(s.id) ($(s.id in getfield.(val_data, :id) ? "Test" : "Train"))",
)
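In the data-generating model, each country's force-of-infection is scaled by c = tvc * exp(η) with η ~ Normal(0, ω). A quick stand-alone sketch in plain Julia shows how much between-country heterogeneity ω = 0.6 implies. It uses only the Random and Statistics standard libraries, and its variable names are illustrative. Because exp is monotone, the central 95% interval of c is tvc times exp(±1.96ω), roughly 0.09 to 0.97, about a tenfold range.

```julia
# Stand-alone sketch of the country scaling c = tvc * exp(η), η ~ Normal(0, ω).
using Random, Statistics

tvc0, ω0 = 0.3, 0.6                  # same values as data_params above

# exp is monotone, so quantiles of c are exp-transformed quantiles of η.
z975 = 1.959963984540054             # 97.5% quantile of the standard normal
c_median = tvc0                      # median of η is 0, so median of c is tvc
c_lo = tvc0 * exp(-z975 * ω0)
c_hi = tvc0 * exp(z975 * ω0)

# Monte Carlo check with many simulated "countries".
rng = MersenneTwister(1)
c_draws = tvc0 .* exp.(ω0 .* randn(rng, 100_000))
frac_inside = mean(c_lo .<= c_draws .<= c_hi)
c_sample_median = median(c_draws)
```

The lognormal form keeps every country's scaling positive. The median country sits exactly at tvc.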
2 Modeling the data
Now, observing this data but assuming we do not know the original data-generating model, we can define a simple model. We represent the fraction of the population that is susceptible to the virus as \(S\), starting from 1 for newborns. This fraction decreases with age at a rate given by the product of an age-dependent force-of-infection and the proportion of people who are still susceptible at that age. Note that while our independent variable is age, DeepPumas represents it as time t.
The susceptible fraction of the population for country \(i\), \(S_i\), is defined as follows:
\[\begin{equation} \frac{dS_i}{dAge} = - λ_i(Age) * S_i \end{equation}\]
where \(λ_i(Age)\) is unknown. We use a neural network to estimate \(λ_i(Age) \approx NN(Age, \eta_i)\), assuming a single, shared functional form for the force-of-infection across countries. Rather than modeling each country with an entirely independent function, we allow the shared function to be adjusted by a single, tunable random effect, \(\eta_i\), specific to each country. When fitting with a likelihood that marginalizes over the random effects (FO, FOCE, LaplaceI), this single degree of freedom, \(\eta\), allows us to account for cross-country variability effectively.
The theoretical implications of using marginalizing likelihoods in NLME models go beyond the scope of this tutorial, but in essence, this approach allows the embedded neural network to capture common patterns across the data, while relying on a single dimension of variability that meaningfully influences outcomes across countries. Using two random effects would allow us to capture an additional dimension of variability.
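The susceptible-fraction equation is linear in \(S_i\), so it can be solved by separation of variables. The following short derivation is standard calculus rather than anything specific to DeepPumas. It makes explicit that the observations depend on \(λ_i\) only through its integral over age:
\[\begin{aligned}
\frac{1}{S_i}\frac{dS_i}{dAge} &= -λ_i(Age) \\
\ln S_i(Age) - \ln S_i(0) &= -\int_0^{Age} λ_i(a)\, da \\
S_i(Age) &= \exp\left(-\int_0^{Age} λ_i(a)\, da\right), \qquad S_i(0) = 1
\end{aligned}\]
The fraction of a country's population infected by a given age is therefore \(1 - S_i(Age)\). This fraction can only increase with age as long as \(λ_i(Age) > 0\).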
The neural network itself is a simple multi-layer perceptron (MLP) with normalized age, \(Age/100\), and \(\eta\) as inputs. It consists of two hidden layers, each with five nodes using the tanh activation function, and a single output node with a softplus activation function to ensure a positive output.
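The architecture can be sketched without DeepPumas. The block below is a hypothetical plain-Julia stand-in for the network described above. It uses random, unfitted weights and illustrative names such as nn_foi, and it is not how MLPDomain is implemented internally. Assuming a bias in every layer, the network has 51 parameters. Feeding η as a second input turns one set of weights into a whole family of positive curves, one per value of η.

```julia
# Hypothetical plain-Julia stand-in for a 2 → 5 → 5 → 1 MLP (not DeepPumas internals).
using Random

# Numerically stable softplus, log(1 + exp(x)), which is always positive.
softplus_stable(x) = max(x, zero(x)) + log1p(exp(-abs(x)))

# Random, unfitted weights for illustration; in the tutorial they are estimated.
rng_nn = MersenneTwister(42)
W1, b1 = randn(rng_nn, 5, 2), randn(rng_nn, 5)
W2, b2 = randn(rng_nn, 5, 5), randn(rng_nn, 5)
W3, b3 = randn(rng_nn, 1, 5), randn(rng_nn, 1)

# Inputs: normalized age (age / 100) and the country's random effect η.
function nn_foi(age, η)
    x = [age / 100, η]
    h1 = tanh.(W1 * x .+ b1)                   # hidden layer 1, tanh
    h2 = tanh.(W2 * h1 .+ b2)                  # hidden layer 2, tanh
    return softplus_stable.(W3 * h2 .+ b3)[1]  # positive scalar output
end

# Weights plus biases: (2*5 + 5) + (5*5 + 5) + (5*1 + 1) = 51
n_params = sum(length, (W1, b1, W2, b2, W3, b3))

# One set of weights, two "countries": only η differs.
curve_a = [nn_foi(a, -0.1) for a in 0:100]
curve_b = [nn_foi(a, 0.1) for a in 0:100]
```

In the fitted model, all countries share the weights, so the only per-country freedom is the scalar η.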
The data contains noise, which is likely related to the population size, \(N\), used to compute the fraction of positive test results for each age group (with ages binned in 1-year intervals). While in some datasets we might have direct access to \(N\), here we estimate it. We model the observed fraction of positive tests, whose expected value is \(1 - S\), as samples from a Beta((1 - S) * N, S * N) distribution. To avoid numerical and boundary issues, we modify this slightly to Beta(abs(1 - 0.99 * S) * N, abs(0.99 * S * N)). Without the 0.99 factor, the first shape parameter would be exactly zero at birth, where S = 1, which is not a valid Beta distribution.
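Computing the moments of this Beta distribution shows how \(N\) acts as a noise parameter. Both shape parameters are positive, so they sum to \(N\). The mean is therefore \(m = 1 - 0.99S\) and the variance is \(m(1 - m)/(N + 1)\), so a larger \(N\) means less noise. The hypothetical plain-Julia helper below checks this numerically. It is not part of DeepPumas.

```julia
# Moments of the observation model PosFrac ~ Beta(N(1 - 0.99S), N·0.99S).
# Hypothetical helper, plain Julia; returns (mean, variance).
function beta_obs_moments(S, N)
    α = abs(N * (1 - 0.99S))
    β = abs(N * 0.99S)
    m = α / (α + β)                        # equals 1 - 0.99S
    v = α * β / ((α + β)^2 * (α + β + 1))  # equals m(1 - m) / (N + 1)
    return m, v
end

m_half, v_half = beta_obs_moments(0.5, 1000.0)

# Larger N means a tighter distribution around the same mean.
sd_small_N = sqrt(beta_obs_moments(0.5, 100.0)[2])
sd_large_N = sqrt(beta_obs_moments(0.5, 10_000.0)[2])

# At birth (S = 1) the first shape parameter is 0.01N > 0, so the Beta stays valid.
m_birth, _ = beta_obs_moments(1.0, 1000.0)
```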
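model = @model begin
    @param begin
        # Effective population size behind each positive fraction (Beta precision)
        N ∈ RealDomain(; lower = 0.0, init = 1000.0)
        # Neural force of infection: inputs (Age/100, η), two hidden layers of 5 nodes, softplus output
        λ ∈ MLPDomain(2, 5, 5, (1, softplus))
    end
    @random η ~ Normal(0, 0.1)
    @init S = 1
    @dynamics begin
        S' = -λ(t / 100, η)[1] * S
    end
    @derived PosFrac ~ @. Beta(abs(N * (1 - 0.99S)), abs(N * 0.99S))
    @observed begin
        iλ = age -> λ[1](age / 100, η[1])[1]
    end
end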
With the model defined, we fit it by maximum a posteriori (MAP) estimation, marginalizing over the random effects with the Laplace approximation: MAP(LaplaceI()).
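LaplaceI() approximates each country's marginal likelihood by a Gaussian integral around the mode of its random effect. The stand-alone sketch below illustrates the idea on an assumed one-dimensional Gaussian toy model, not the tutorial's model. In this toy model the approximation happens to be exact, so it can be checked against the closed-form marginal. The function names are hypothetical, and derivatives are taken by finite differences rather than automatic differentiation.

```julia
# Toy model (an assumption, not the tutorial's model): y ~ Normal(η, s), η ~ Normal(0, ω).
toy_normlogpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# log p(y, η) = log p(y | η) + log p(η)
toy_logjoint(η; y, s, ω) = toy_normlogpdf(y, η, s) + toy_normlogpdf(η, 0.0, ω)

# Laplace: log ∫ exp(f(η)) dη ≈ f(η̂) + ½ log(2π / H), where η̂ maximizes f
# and H = -f''(η̂). Derivatives use central finite differences.
function laplace_logmarginal(f; η0 = 0.0, h = 1e-3, iters = 20)
    d1(x) = (f(x + h) - f(x - h)) / (2h)
    d2(x) = (f(x + h) - 2f(x) + f(x - h)) / h^2
    η = η0
    for _ in 1:iters
        η -= d1(η) / d2(η)   # Newton step toward the mode η̂
    end
    H = -d2(η)
    return f(η) + 0.5 * log(2π / H)
end

y_obs, s_obs, ω_prior = 1.0, 0.5, 1.0
laplace_val = laplace_logmarginal(η -> toy_logjoint(η; y = y_obs, s = s_obs, ω = ω_prior))

# Exact marginal for this Gaussian-Gaussian model: y ~ Normal(0, √(s² + ω²)).
exact_val = toy_normlogpdf(y_obs, 0.0, sqrt(s_obs^2 + ω_prior^2))
```

For the tutorial's model the log joint density is not quadratic in η, so the approximation is no longer exact, but the mechanics are the same.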
fpm = fit(
    model,
    train_data,
    init_params(model),
    MAP(LaplaceI());
    checkidentification = false,
    optim_options = (; show_every = 100),
)
Having fitted the model, we now apply it to all data — including the test data we withheld during training.
The predict function generates two types of predictions. The first is the “Central tendency”, also known as the “Population prediction”. This is a true prediction that does not “peek” at the data being predicted. Since we do not have covariates or additional information beyond the observed infection proportion, this prediction will be identical across all countries. For this prediction, we set the \(\eta\) value to the mean of the prior distribution.
The second type of prediction uses \(\eta\) set to the mean of the approximated random effect posterior distribution, which is based on all the observations available from the country in question. This is often referred to as an “Individual prediction”.
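A toy empirical Bayes calculation illustrates the difference between the two prediction types. The sketch below assumes a deliberately simple model rather than the DeepPumas model: a constant hazard of 0.05 per year scaled by exp(η), Gaussian observation noise with standard deviation 0.05, and the prior η ~ Normal(0, 0.6). The helper names are hypothetical. With no observations, the posterior mode is the prior mean, which gives the population prediction. With two noise-free observations from a country whose true η is 0.5, the mode moves most of the way toward 0.5 but is shrunk slightly toward 0.

```julia
# Toy empirical-Bayes sketch (assumed model, not the DeepPumas implementation).
# Positive fraction at age a: 1 - exp(-exp(η) * 0.05 * a)
toy_pos_frac(a, η) = 1 - exp(-exp(η) * 0.05 * a)

function toy_log_posterior(η, ages, ys; obs_sd = 0.05, prior_sd = 0.6)
    lp = -0.5 * (η / prior_sd)^2                          # log prior (up to a constant)
    for (a, y) in zip(ages, ys)
        lp -= 0.5 * ((y - toy_pos_frac(a, η)) / obs_sd)^2  # Gaussian log likelihood terms
    end
    return lp
end

# Posterior mode by brute-force grid search; adequate for a single scalar η.
function toy_ebe(ages, ys; grid = range(-3, 3; length = 6001))
    return grid[argmax([toy_log_posterior(η, ages, ys) for η in grid])]
end

η_true = 0.5
sparse_ages = [10, 40]                             # a country with only two observations
sparse_ys = toy_pos_frac.(sparse_ages, η_true)     # noise-free for clarity

η_pop = toy_ebe(Int[], Float64[])      # no data: the prior mean, 0
η_ind = toy_ebe(sparse_ages, sparse_ys) # pulled toward η_true, shrunk toward 0

# Population vs individual prediction of the positive fraction at age 60
pop_pred = toy_pos_frac(60, η_pop)
ind_pred = toy_pos_frac(60, η_ind)
```

In the tutorial, the shared network and the effective prior are learned from data-rich countries. That is how a country with one or two observations can borrow strength from the others.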
Our goal is to make accurate predictions of how the force-of-infection and the burden of disease change with age in countries with sparse data. To evaluate the model, we make individual predictions for all countries, including those withheld from training, and plot the results together with the observed data. We also overlay the latent, noise-free truth, which is available because we know how the data was generated.
plotgrid(
    predict(fpm, all_data; obstimes = 1:100);
    xlabel = "AGE",
    ylabel = "Fraction positive tests",
    pred = (; label = "Central tendency", color = (:black, 0.5)),
    ipred = (; label = "Individual prediction"),
    title = (s, i) -> "$(s.id) ($(s.id in getfield.(val_data, :id) ? "Test" : "Train"))",
    data = (; markersize = 10),
    axis = (; yticks = 0:0.25:1),
)
plotgrid!(
    simobs(
        datamodel,
        all_data,
        data_params,
        getfield.(all_sims, :randeffs);
        simulate_error = false,
        obstimes = 1:100,
    );
    sim = (; linestyle = :dash, markersize = 0, linewidth = 2, label = "Truth"),
)
The fitted model captures the observed data and underlying trends remarkably well, even for countries like Vietnam, which in this case has only two observations. Without leveraging information on typical infectiousness patterns from other countries, estimating the full infection-age curve for Vietnam would have been impossible.
The model accurately matches positive test rates by identifying a country-specific functional form for the force-of-infection. This estimated force-of-infection curve can, of course, be extracted directly from the fitted model.
ages = fill(age, length(all_countries))

model_ebes = empirical_bayes(model, all_data, coef(fpm), FOCE())
model_sims = simobs(model, all_data, coef(fpm), model_ebes)
predicted = map(s -> s.observed.iλ.(age), model_sims)

# Since we use synthetic data, we have the luxury of extracting the ground truth for comparison.
actual = map(s -> s.observed.iλ.(age), all_sims)

_df = DataFrame(; all_countries, ages, actual, predicted)
_df = DataFrames.flatten(_df, [:ages, :actual, :predicted])
_df = stack(_df, [:actual, :predicted], variable_name = :λ)

axis = (width = 150, height = 150)
plt =
    data(_df) *
    mapping(:ages => "Age", :value => "λ(Age)", color = :λ, layout = :all_countries)
plt *= visual(Lines)
draw(plt; axis = axis)
Here again, we see that the model estimates align closely with the true values used to generate the data. The model captures the main trends effectively, even in sparsely observed test data. Notably, Vietnam — excluded from model fitting, with only two observations and a distinct force-of-infection profile — was still well-predicted.
3 Conclusions
In summary, the model successfully:
- Identified a reasonable functional form for the force-of-infection,
- Individualized this force-of-infection using a single random effect,
- Found an appropriate transformation of the random effect distribution, making the prior informative. This transformation helps encode information from rich data sources, guiding accurate predictions even in countries with sparse data.
This tutorial provides just one example of how DeepNLME and DeepPumas can elucidate complex dynamics using data from heterogeneous sources. The same methodology can be applied to more intricate models, such as SIR models where time (rather than age) is the independent variable. Furthermore, after characterizing between-country (or entity-specific) variability through random effects, we can explore predictive patterns in geographical or demographic data to help account for some of this variability.