# Bring packages to environment
import Pkg
Pkg.activate(".")
Pkg.instantiate()
using KernelFunctions, PharmaDatasets, Statistics
using MLJ, LinearAlgebra, CSV, DataFrames
using Random, Distances, Clustering, MultivariateStats
Random.seed!(378); # Set random seed for reproducibility
# Load models
KMedoids = @load KMedoids pkg = Clustering
EvoTreeClassifier = @load EvoTreeClassifier pkg = EvoTrees
knn = @load KNNClassifier pkg = NearestNeighborModels
ContinuousEncoder = @load ContinuousEncoder pkg = MLJModels
KernelPCA = @load KernelPCA pkg = MultivariateStats
PCA = @load PCA pkg = MultivariateStats
data:image/s3,"s3://crabby-images/5bf99/5bf9981d0276487b448ca81be06f1e8780483bd8" alt="Pumas Logo"
Machine learning fundamentals
1 Introduction
In the context of the first unit of the “AI for drug development” course, this tutorial uses real data from a published study to introduce machine learning (ML) fundamentals; the goal is didactic, not the development of robust clinical tools. The discussion is aimed at healthcare professionals and academics, especially in pharmacometrics.
Furthermore, the workflow presented here assumes no prior knowledge of AI. The focus of each section is as follows:
- We begin by summarizing the context of the study below.
- In Section 2, common data-processing steps are presented, such as imputation, one-hot encoding, and standardization.
- In Section 3, the distinction between supervised and unsupervised applications is explained. Some unsupervised models are also discussed there, focusing on dimensionality reduction (through principal component analysis) and clustering (including K-Medoids and K-Means).
- In Section 4, we present decision trees, ensembles, random forests, and gradient-boosted trees. That section also details options for model evaluation in binary classification. Lastly, K-nearest neighbors is explained, and references for further study are listed.
In the case study analyzed, the goal is to determine the main predictors of response to a treatment for critical limb ischemia (CLI), an advanced-stage disease characterized by the blockage of arteries in a limb. The authors of “Characteristics of responders to autologous bone marrow cell therapy for no-option critical limb ischemia” (Madaric et al. 2016) made their dataset available as a spreadsheet, containing multiple features for the 62 patients who initially enrolled in the study. The endpoints of interest in the study are:
- Primary endpoints:
  - Responder: surviving patients with limb salvage and wound healing.
  - Non-responders: patients requiring major limb amputation or those with no signs of wound healing.
- Secondary endpoints:
  - Mortality.
  - Amputation-free survival.
  - Major limb amputation.
  - Change in tcpO2.
  - Rutherford category.
  - Pain scale after BMC transplantation.
From the point of view of ML, this leads to a classification problem: there is a finite number of discrete possible outcomes. More specifically, it is a binary classification problem, since a patient either responded to the treatment or did not. In contrast, a regression problem has a continuous outcome, such as forecasting energy demand or traffic congestion. Other high-level groupings of ML algorithms are also possible (Brownlee 2023).
It is important to point out that a dataset with 62 data points is extremely small for ML applications; even datasets with hundreds of points are generally considered small. Still, the main concepts can be demonstrated, although the specific performance values obtained should be taken as illustrative only.
First, some Julia packages will be necessary. The MLJ.jl package will be used throughout the tutorial, but the ML concepts and techniques discussed here are tool-agnostic.
2 Data processing
ML projects tend to follow very similar steps (Mehreen 2022). To start the analysis, the file is read and some preprocessing is done. This includes converting empty strings ("") and "#NA!" entries into the value missing. In the data overview, the nmissing column indicates the number of missing values in each dataset column.
# Read spreadsheet data
df = CSV.read(dataset("CLIstemCell", String), DataFrame; missingstring = ["", "#NA!"])
describe(df)[:, Not(7)] |> println; # Print overview
47×6 DataFrame
Row │ variable mean min median max nmissi ⋯
│ Symbol Float64 Real Float64 Real Int64 ⋯
─────┼──────────────────────────────────────────────────────────────────────────
1 │ PatientID 31.5 1 31.5 62 ⋯
2 │ HealedAmputFreeSurvival 0.532258 0.0 1.0 1.0
3 │ Survival 0.887097 0.0 1.0 1.0
4 │ LimbSalvage 0.709091 0.0 1.0 1.0
5 │ AmputFreeSurvival 0.629032 0.0 1.0 1.0 ⋯
6 │ HealedLimbSalvage 0.6 0.0 1.0 1.0
7 │ tcpO2 13.9194 0.0 13.5 33.0
8 │ tcpO2 at 6M 29.65 1 30.0 67
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
41 │ sICAM-1 744.629 256.36 622.97 3220.08 ⋯
42 │ sICAM-3 126.915 0.0 105.89 399.11
43 │ sEselectin 140.433 46.61 128.04 670.85
44 │ sPselectin 803.58 79.24 640.11 4043.14
45 │ sPECAM-1 418.588 219.14 379.9 832.93 ⋯
46 │ VEGF 170.164 15.675 68.239 709.564
47 │ ApplicationRoute 0.516129 0.0 1.0 1.0
1 column and 32 rows omitted
2.1 Imputation
Now, the dataset contains only missing values and numerical values of types Float64 and Int64. The next data-processing step is imputation: assigning a value to every missing entry. For this, MLJ.jl provides the FillImputer() utility, which, by default, fills each column with its median or its mode, according to the column’s type of data. This is a very simple approach to imputation, and many alternatives exist (Airbyte 2024).
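To make the median/mode rule concrete, here is a minimal sketch in plain Julia (Statistics only) on two hypothetical toy columns. It mimics the idea behind FillImputer’s default behavior but is not its implementation; the column names and values are made up for illustration.

```julia
using Statistics

# Hypothetical toy columns (not from the CLI dataset)
numCol = [1.0, missing, 3.0, 10.0]     # numerical column: fill with the median
catCol = ["a", "b", missing, "b"]      # categorical column: fill with the mode

# Median of the observed (non-missing) numerical values
fillNum = median(skipmissing(numCol))

# Mode: the most frequent observed category (argmax(f, itr) requires Julia >= 1.7)
observed = collect(skipmissing(catCol))
fillCat = argmax(v -> count(==(v), observed), unique(observed))

# Replace each missing entry with the fill value
impNum = coalesce.(numCol, fillNum)    # [1.0, 3.0, 3.0, 10.0]
impCat = coalesce.(catCol, fillCat)    # ["a", "b", "b", "b"]
```

Here coalesce.(v, x) replaces each missing entry of v with x and leaves the observed values untouched.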
# Impute missing values by median and mode
machImp = machine(FillImputer(), df) # Attach model to data
MLJ.fit!(machImp, verbosity = 0)
impDF = MLJ.transform(machImp, df) # Perform imputation
describe(impDF)[:, Not(7)] |> println; # Overview
47×6 DataFrame
Row │ variable mean min median max nmissi ⋯
│ Symbol Float64 Real Float64 Real Int64 ⋯
─────┼──────────────────────────────────────────────────────────────────────────
1 │ PatientID 31.5 1 31.5 62 ⋯
2 │ HealedAmputFreeSurvival 0.532258 0.0 1.0 1.0
3 │ Survival 0.887097 0.0 1.0 1.0
4 │ LimbSalvage 0.741935 0.0 1.0 1.0
5 │ AmputFreeSurvival 0.629032 0.0 1.0 1.0 ⋯
6 │ HealedLimbSalvage 0.645161 0.0 1.0 1.0
7 │ tcpO2 13.9194 0.0 13.5 33.0
8 │ tcpO2 at 6M 29.7742 1 30.0 67
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
41 │ sICAM-1 711.271 256.36 622.97 3220.08 ⋯
42 │ sICAM-3 121.15 0.0 105.89 399.11
43 │ sEselectin 137.435 46.61 128.04 670.85
44 │ sPselectin 758.757 79.24 640.11 4043.14
45 │ sPECAM-1 407.98 219.14 379.9 832.93 ⋯
46 │ VEGF 152.081 15.675 68.239 709.564
47 │ ApplicationRoute 0.516129 0.0 1.0 1.0
1 column and 32 rows omitted
Note that, although the nmissing column now contains only zeros, the means and medians changed little from the previous step. This is a good sign, indicating that imputation did not significantly distort the original data distribution. Next, as an exploration step, the histograms in Figure 1 are plotted to get more familiar with the data.
# Activate plotting backend
GLMakie.activate!()
fig = Figure(; fontsize = 18) # Create figure to draw in
# Axis and plot for each histogram
ax = Axis(fig[1, 1]; title = "Age (y)")
hist!(ax, df[:, "Age"])
ax = Axis(fig[1, 2]; title = "Cholesterol (mmol/L)")
hist!(ax, df[:, "Cholesterol"])
ax = Axis(fig[2, 1]; title = "Leu (10⁹/L)")
hist!(ax, df[:, "Leu"])
ax = Axis(fig[2, 2]; title = "Creatinine (μmol/L)")
hist!(ax, df[:, "Creatinine"])
save("./case1_hists.png", fig) # Save PNG
2.2 Coercing data types
Another data-processing step, necessary for MLJ.jl, is managing the columns’ “scientific types”, such as Continuous, Count, and Binary. Scientific types determine which ML models are compatible with the inputs and outputs of an application.
In the current case, the inputs are “coerced” into Continuous, and the outputs into Binary (whether or not the patient responded to the treatment). Afterwards, the compatible models are listed to showcase the diversity of ML models.
# Subset of columns used as inputs
colIDs = vcat([7, 10, 13], 16:32, [35, 38], 41:46)
inputs = impDF[:, colIDs] # 62x28
# Primary endpoint as binary prediction target
outputs = collect(impDF[:, "HealedAmputFreeSurvival"])
# Change scientific type of column "Rutherford"
inputs = coerce(inputs, Symbol("Rutherford") => OrderedFactor)
# Coerce data types of remaining columns
machCont = machine(ContinuousEncoder(; one_hot_ordered_factors = true), inputs)
fit!(machCont, verbosity = 0)
coercedInputs = MLJ.transform(machCont, inputs) # 62x29
outputs = coerce(outputs, Binary) # Coerce target into binary
# List of available models in MLJ.jl
models(matching(coercedInputs, outputs)) .|> println;
(name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )
(name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )
(name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )
(name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )
(name = BayesianLDA, package_name = MultivariateStats, ... )
(name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )
(name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )
(name = CatBoostClassifier, package_name = CatBoost, ... )
(name = ConstantClassifier, package_name = MLJModels, ... )
(name = DecisionTreeClassifier, package_name = BetaML, ... )
(name = DecisionTreeClassifier, package_name = DecisionTree, ... )
(name = DeterministicConstantClassifier, package_name = MLJModels, ... )
(name = DummyClassifier, package_name = MLJScikitLearnInterface, ... )
(name = EvoTreeClassifier, package_name = EvoTrees, ... )
(name = ExtraTreesClassifier, package_name = MLJScikitLearnInterface, ... )
(name = GaussianNBClassifier, package_name = MLJScikitLearnInterface, ... )
(name = GaussianNBClassifier, package_name = NaiveBayes, ... )
(name = GaussianProcessClassifier, package_name = MLJScikitLearnInterface, ... )
(name = GradientBoostingClassifier, package_name = MLJScikitLearnInterface, ... )
(name = HistGradientBoostingClassifier, package_name = MLJScikitLearnInterface, ... )
(name = KNNClassifier, package_name = NearestNeighborModels, ... )
(name = KNeighborsClassifier, package_name = MLJScikitLearnInterface, ... )
(name = KernelPerceptronClassifier, package_name = BetaML, ... )
(name = LDA, package_name = MultivariateStats, ... )
(name = LGBMClassifier, package_name = LightGBM, ... )
(name = LinearBinaryClassifier, package_name = GLM, ... )
(name = LinearSVC, package_name = LIBSVM, ... )
(name = LogisticCVClassifier, package_name = MLJScikitLearnInterface, ... )
(name = LogisticClassifier, package_name = MLJLinearModels, ... )
(name = LogisticClassifier, package_name = MLJScikitLearnInterface, ... )
(name = MultinomialClassifier, package_name = MLJLinearModels, ... )
(name = NeuralNetworkBinaryClassifier, package_name = MLJFlux, ... )
(name = NeuralNetworkClassifier, package_name = BetaML, ... )
(name = NeuralNetworkClassifier, package_name = MLJFlux, ... )
(name = NuSVC, package_name = LIBSVM, ... )
(name = PassiveAggressiveClassifier, package_name = MLJScikitLearnInterface, ... )
(name = PegasosClassifier, package_name = BetaML, ... )
(name = PerceptronClassifier, package_name = BetaML, ... )
(name = PerceptronClassifier, package_name = MLJScikitLearnInterface, ... )
(name = ProbabilisticNuSVC, package_name = LIBSVM, ... )
(name = ProbabilisticSGDClassifier, package_name = MLJScikitLearnInterface, ... )
(name = ProbabilisticSVC, package_name = LIBSVM, ... )
(name = RandomForestClassifier, package_name = BetaML, ... )
(name = RandomForestClassifier, package_name = DecisionTree, ... )
(name = RandomForestClassifier, package_name = MLJScikitLearnInterface, ... )
(name = RidgeCVClassifier, package_name = MLJScikitLearnInterface, ... )
(name = RidgeClassifier, package_name = MLJScikitLearnInterface, ... )
(name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )
(name = SRRegressor, package_name = SymbolicRegression, ... )
(name = SVC, package_name = LIBSVM, ... )
(name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )
(name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )
(name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )
(name = StableForestClassifier, package_name = SIRUS, ... )
(name = StableRulesClassifier, package_name = SIRUS, ... )
(name = SubspaceLDA, package_name = MultivariateStats, ... )
(name = XGBoostClassifier, package_name = XGBoost, ... )
As for which features to use, colIDs includes almost all of them, excluding most secondary endpoints. Snapshots taken after baseline were also excluded. Snapshots are features measured at multiple moments: right before treatment (baseline), and 6 and 12 months after treatment. They are tcpO2, ABI (ankle-brachial index), wound size, and the Rutherford, pain, and quality-of-life scales.
Since the ML models developed in this tutorial aim to predict whether a new patient suffering from CLI would benefit from the proposed treatment, only the baseline values of snapshot variables are used. However, one could also use such models 6 months after treatment, to obtain a more precise prediction of how the patient’s condition will develop over the following 6 months. Including the 6-month snapshots is left as an exercise for the reader.
Moving on, besides coercion, the previous code also performs one-hot encoding, through ContinuousEncoder(; one_hot_ordered_factors = true). For a variable that can only take values from a finite set of discrete options (called “categorical”), this process replaces the variable with multiple binary variables.
For example, in the current case, some features are ratings on the Rutherford scale for critical limb ischemia, which grades the severity of the disease in discrete categories. One of these features is “Rutherford at 6M” (column 11), the rating measured 6 months after treatment.
oneHotDF = impDF[:, [11, 13]] # "Rutherford at 6M" and "Pain" columns (62x2)
oneHotDF = coerce(oneHotDF, Symbol("Rutherford at 6M") => OrderedFactor)
machOneHot = machine(ContinuousEncoder(; one_hot_ordered_factors = true), oneHotDF)
fit!(machOneHot, verbosity = 0)
first(oneHotDF, 10) |> println # Overview before one-hot encoding
println()
oneHotDF = MLJ.transform(machOneHot, oneHotDF) # 62x7 DataFrame
first(oneHotDF, 10) |> println; # Overview after one-hot encoding
10×2 DataFrame
Row │ Rutherford at 6M Pain
│ Categorical… Float64
─────┼───────────────────────────
1 │ 3 0.0
2 │ 5 4.0
3 │ 5 7.0
4 │ 5 5.0
5 │ 2 0.0
6 │ 1 3.0
7 │ 5 6.0
8 │ 3 0.0
9 │ 5 5.0
10 │ 5 6.0
10×7 DataFrame
Row │ Rutherford at 6M__1 Rutherford at 6M__2 Rutherford at 6M__3 Rutherfo ⋯
│ Float64 Float64 Float64 Float64 ⋯
─────┼──────────────────────────────────────────────────────────────────────────
1 │ 0.0 0.0 1.0 ⋯
2 │ 0.0 0.0 0.0
3 │ 0.0 0.0 0.0
4 │ 0.0 0.0 0.0
5 │ 0.0 1.0 0.0 ⋯
6 │ 1.0 0.0 0.0
7 │ 0.0 0.0 0.0
8 │ 0.0 0.0 1.0
9 │ 0.0 0.0 0.0 ⋯
10 │ 0.0 0.0 0.0
4 columns omitted
One-hot encoding turns this categorical column with 6 levels into 6 binary variables, “Rutherford at 6M__i”, one for each level i of the categorical variable. For each patient, these new binary variables are all 0 except one, which is 1 and indicates the value in the original categorical column.
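The encoding itself can be sketched without MLJ.jl. In this minimal example on hypothetical ratings (not the study data), one 0/1 column is created per level that actually occurs in the data, so a level absent from the data (here, 4) gets no column in this sketch.

```julia
# Hypothetical ratings on an ordinal 1 to 6 scale (not the study data)
ratings = [3, 5, 5, 2, 1, 6]
lvls = sort(unique(ratings))            # observed levels: [1, 2, 3, 5, 6]

# One column per observed level; each row has a single 1.0 marking its level
oneHot = [Float64(r == l) for r in ratings, l in lvls]
```

Row 1 (rating 3) becomes [0.0, 0.0, 1.0, 0.0, 0.0], and every row sums to exactly 1.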
Lastly, it is worth pointing out that coercedInputs, the variable used in this tutorial’s pipeline, contains the Rutherford measurements at baseline (beginning of the study). Since only patients in critical condition participated in the study, all of these values are 5 or 6. Therefore, only two new binary variables were created: “Rutherford__5.0” and “Rutherford__6.0”.
2.3 Standardization
Afterwards, each column is standardized, which means subtracting the mean and dividing by the standard deviation.
\[ \tilde{x}_{i} = \frac{x_{i} - \mathbb{E}[x]}{\sqrt{\operatorname{Var}[x]}} \]
Here, \(x_{i}\) is a value in a given column, and \(\mathbb{E}[x]\) and \(\operatorname{Var}[x]\) are that column’s mean and variance. As a result, all inputs have zero mean and unit standard deviation. This alleviates some numerical problems that may arise during training. An alternative is quantile normalization (Bolstad et al. 2003), which can be achieved through MLJ’s UnivariateDiscretizer.
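The formula can be checked numerically. The sketch below standardizes each column of a small hypothetical matrix with the sample mean and standard deviation (Statistics.mean and Statistics.std); MLJ’s Standardizer is assumed to behave similarly, but its exact conventions are not reproduced here.

```julia
using Statistics

X = [2.0 10.0; 4.0 20.0; 6.0 60.0]      # hypothetical data: 3 samples x 2 features
μ = mean(X; dims = 1)                   # column means (1x2)
σ = std(X; dims = 1)                    # column sample standard deviations (1x2)
Z = (X .- μ) ./ σ                       # subtract the mean, divide by the std

Z[:, 1]                                 # [-1.0, 0.0, 1.0]
```

After the transformation, every column of Z has (numerically) zero mean and unit standard deviation, regardless of its original scale.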
# Standardization
machStd = machine(Standardizer(), coercedInputs)
# Columns' mean and std calculated through 'fitting'
MLJ.fit!(machStd, verbosity = 0)
# Use established mean and std to transform data
stdInputs = MLJ.transform(machStd, coercedInputs)
# Check transformation
stdCols = eachcol(stdInputs)
DataFrame(
    Feature = names(stdInputs),
    Min = minimum.(stdCols),
    Mean = mean.(stdCols),
    Max = maximum.(stdCols),
    SD = std.(stdCols),
) |> println;
29×5 DataFrame
Row │ Feature Min Mean Max SD
│ String Float64 Float64 Float64 Float64
─────┼──────────────────────────────────────────────────────────────────
1 │ tcpO2 -1.42178 -5.49236e-17 1.94898 1.0
2 │ Rutherford__5.0 -3.03031 7.52087e-17 0.324676 1.0
3 │ Rutherford__6.0 -0.324676 -2.68602e-18 3.03031 1.0
4 │ Pain -2.00892 -6.46884e-17 2.22758 1.0
5 │ Sex -0.381784 5.37205e-18 2.57704 1.0
6 │ Age -2.39604 4.67368e-16 2.27203 1.0
7 │ ArterialHypertension -1.83665 3.58136e-18 0.535689 1.0
8 │ DiabetesMellitus -1.38596 1.14604e-16 0.709883 1.0
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮
23 │ Wound size base -0.762213 -2.14882e-17 3.62204 1.0
24 │ sICAM-1 -1.12333 4.92438e-17 6.19511 1.0
25 │ sICAM-3 -1.85742 -5.30937e-16 4.26157 1.0
26 │ sEselectin -1.07723 -3.95293e-16 6.32657 1.0
27 │ sPselectin -1.27739 -2.05033e-16 6.17412 1.0
28 │ sPECAM-1 -1.45862 -6.81355e-16 3.28237 1.0
29 │ VEGF -0.770142 1.45045e-16 3.14753 1.0
14 rows omitted
3 Unsupervised learning
In ML, each data point is usually an \((input, label)\) pair. When a model processes the input, it is expected to output the \(label\), or at least something close to it.
This rationale defines supervised ML applications: those with a label associated with each data point. Common examples of pairs are \((image, category)\) or, for the current problem, \((patient \: data, treatment \: response)\).
On the other hand, the ML subfield of unsupervised learning builds models for data without labels. Its main uses are to gain insights about the data and learn patterns, instead of making predictions.
3.1 Dimensionality reduction
Within unsupervised learning, dimensionality reduction (Jacob Murel 2024) reduces the number of dimensions (features) of the data. According to Géron (2019), this is useful to:
- Avoid the curse of dimensionality.
- Compress the information in the data.
- Visualize the dataset.
- Speed up training.
A classic example of an unsupervised learning technique is Principal Component Analysis (PCA) (Jaadi 2024). It projects the data points onto a lower-dimensional space, searching for orthogonal axes that account for as much of the data’s variance as possible. Figure 2 illustrates how much information PCA retains as more principal components are included.
matIn = Matrix(stdInputs) # Inputs as matrix
pca = MultivariateStats.fit(MultivariateStats.PCA, matIn'; maxoutdim = 31)
valsP = eigvals(pca) # Eigenvalues
fig = Figure(; size = (800, 500), fontsize = 18)
ax = Axis(
    fig[1, 1],
    xlabel = "Number of principal components",
    ylabel = "Percentage of variance",
    yticks = 0:10:100,
    xticks = 1:2:31,
    ytickformat = "{:d}%",
)
cumulativePercentage = cumsum(valsP ./ sum(valsP)) .* 100
lines!(ax, cumulativePercentage)
save("./eigVar.png", fig)
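The cumulative-variance curve computed above can also be obtained from scratch. As a sketch on hypothetical 3-D data (not the patient data), PCA amounts to an eigendecomposition of the sample covariance matrix: the eigenvalues, sorted from largest to smallest, are the variances along the principal axes, and their normalized cumulative sum is the fraction of variance retained.

```julia
using LinearAlgebra, Random, Statistics

rng = MersenneTwister(1)
# Hypothetical data: 100 points in 3-D, spread mostly along a single direction
t = randn(rng, 100)
X = hcat(t, 2 .* t .+ 0.1 .* randn(rng, 100), 0.1 .* randn(rng, 100))

Xc = X .- mean(X; dims = 1)              # center each column
C = cov(Xc)                              # 3x3 sample covariance (columns are variables)
F = eigen(Symmetric(C))                  # eigenvalues come back in ascending order
order = sortperm(F.values; rev = true)   # reorder from largest to smallest
vals, vecs = F.values[order], F.vectors[:, order]

explained = cumsum(vals) ./ sum(vals)    # cumulative fraction of variance retained
scores = Xc * vecs[:, 1:2]               # coordinates on the first 2 principal components
```

Because the toy data are almost one-dimensional, the first entry of explained is already close to 1, which is the situation where dropping the remaining components loses little information.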
In the current case, the patient vectors from the previous step can be interpreted mathematically as points in a 29-dimensional space (one dimension per column of stdInputs). In the following, PCA projects these points into a 3-dimensional space, which is visualized as the scatter plot in Figure 3.
machPCAviz = machine(PCA(; maxoutdim = 3), stdInputs)
MLJ.fit!(machPCAviz, verbosity = 0) # Compute principal components
# Project data to lower dimension
lowInputsViz = MLJ.transform(machPCAviz, stdInputs);
A limitation of standard PCA is that it only captures linear relationships between features. Kernel PCA (Thompson 2018) addresses this by using a ‘kernel function’ to (implicitly) map the data to a space of higher dimension, often recovering the linear separability of the points.
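As a rough illustration of the idea, the sketch below performs the core of kernel PCA by hand on hypothetical random data: it builds the Gram matrix of a squared exponential kernel, k(x, y) = exp(-‖x − y‖²/2) (assumed here to match the convention of KernelFunctions’ SqExponentialKernel), centers it, and takes its leading eigenvectors. MLJ’s KernelPCA handles all of this internally; reconstructing the original features (inverse_transform) requires an additional approximation step that is not shown.

```julia
using LinearAlgebra, Random

# Squared exponential kernel k(x, y) = exp(-‖x - y‖² / 2)
sqexp(x, y) = exp(-sum(abs2, x .- y) / 2)

rng = MersenneTwister(2)
X = randn(rng, 20, 3)                        # hypothetical data: 20 points (rows), 3 features
n = size(X, 1)

# Kernel (Gram) matrix: similarity between every pair of points
K = [sqexp(X[i, :], X[j, :]) for i in 1:n, j in 1:n]

# Center the data in the implicit feature space: Kc = H K H, with H = I - 11ᵀ/n
H = I - fill(1 / n, n, n)
Kc = H * K * H

# Leading eigenpairs of the centered kernel matrix define the nonlinear components
F = eigen(Symmetric(Kc))
top = sortperm(F.values; rev = true)[1:2]    # indices of the 2 largest eigenvalues
# Coordinates of the training points on those components: sqrt(λₖ) * vₖ
Z = F.vectors[:, top] .* sqrt.(F.values[top])'
```

Note that the original features never need to be mapped explicitly: only pairwise kernel values are used, which is what makes very high-dimensional (even infinite-dimensional) feature spaces tractable.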
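As a side example, the data are projected from 29 to 10 dimensions using a squared exponential kernel, and the mean squared reconstruction error is then calculated, resulting in 0.84. Since the standardized inputs have unit variance, this value is not far from the error of simply predicting each feature’s mean (about 1), which suggests that this projection retains relatively little of the information in the data.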
machKPCA = machine(
    KernelPCA(; maxoutdim = 10, kernel = (x, y) -> SqExponentialKernel()(x, y)),
    stdInputs,
)
MLJ.fit!(machKPCA, verbosity = 0)
KPCA_low_inputs = MLJ.transform(machKPCA, stdInputs)
# Reconstruction
reconst = inverse_transform(machKPCA, KPCA_low_inputs) |> DataFrame |> Matrix
# Mean squared reconstruction error
reconstError = (Matrix(stdInputs) .- reconst) .^ 2 |> mean;
In comparison, PCA with the same output dimension (10) leads to a mean squared reconstruction error of 0.28. As a result, the PCA projection will be used for the remainder of the tutorial.
machPCA = machine(PCA(; maxoutdim = 10), stdInputs)
MLJ.fit!(machPCA, verbosity = 0)
lowInputs = MLJ.transform(machPCA, stdInputs)
# Reconstruction
reconstPCA = inverse_transform(machPCA, lowInputs) |> DataFrame |> Matrix
# Mean squared reconstruction error
PCAreconstError = (Matrix(stdInputs) .- reconstPCA) .^ 2 |> mean;
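To see what the reconstruction error measures, the following sketch (hypothetical helper pcaReconstructionMSE, random data rather than stdInputs) projects centered data onto the top k principal directions from an SVD, maps it back, and averages the squared differences. The error shrinks as k grows and vanishes when all directions are kept; for standardized data, keeping no components at all would give an error of roughly 1.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical helper: mean squared reconstruction error of PCA with k components
function pcaReconstructionMSE(X::AbstractMatrix, k::Int)
    μ = mean(X; dims = 1)
    Xc = X .- μ                              # center the data
    W = svd(Xc).V[:, 1:k]                    # top-k principal directions (features x k)
    Xhat = (Xc * W) * W' .+ μ                # project down, map back up, restore the mean
    return mean(abs2, X .- Xhat)
end

rng = MersenneTwister(3)
X = randn(rng, 62, 5)                        # hypothetical data: 62 rows, 5 features
errs = [pcaReconstructionMSE(X, k) for k in 1:5]
```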
3.2 K-medoids
It is important to remember that supervised and unsupervised techniques are not mutually exclusive. Each has its uses, pros, and cons, and they can even be combined. For the rest of this tutorial, the 10-dimensional PCA projection from the previous step (lowInputs) is used as the input to the models.
Now the unsupervised learning technique K-medoids (H 2024) will be used. Its goal is clustering: splitting the data into groups of similar points. The algorithm proceeds as follows:
- The number of clusters, K, is specified.
- K data points are chosen as the centers (medoids) of the clusters.
- Each remaining point is assigned to the cluster of the closest medoid.
- Within each cluster, the medoid is replaced by the member that minimizes the sum of distances to all other points in that cluster.
- The last two steps are repeated until the medoids no longer change.
K-means is another common clustering algorithm. The main difference from K-medoids is that cluster centers need not be data points; instead, each center is the average of all points in its cluster.
machKM = machine(KMedoids(; k = 2), lowInputs)
MLJ.fit!(machKM, verbosity = 0)
# Predict point assignments
KMpredictions = MLJ.levelcode.(MLJ.predict(machKM, lowInputs));
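The K-medoids steps described in this section can be sketched in plain Julia. The helper kmedoidsSketch below is hypothetical (it is not the Clustering.jl implementation used by MLJ): it alternates between assigning points to the nearest medoid and re-selecting each cluster’s medoid, stopping once the medoids stop changing. The toy data start with both medoids in the same group, and the algorithm still recovers the two groups.

```julia
using LinearAlgebra

# Hypothetical helper: alternating K-medoids on the rows of X, starting from the
# given medoid indices. Returns (medoid indices, cluster assignment of each row).
function kmedoidsSketch(X::AbstractMatrix, medoids::Vector{Int}; maxiter = 100)
    n = size(X, 1)
    D = [norm(X[i, :] .- X[j, :]) for i in 1:n, j in 1:n]  # pairwise Euclidean distances
    assign = zeros(Int, n)
    for _ in 1:maxiter
        # Assign every point to the cluster of its closest medoid
        assign = [argmin(D[i, medoids]) for i in 1:n]
        # In each cluster, the new medoid is the member with the smallest sum of
        # distances to the other members (K-means would use the mean instead)
        newMedoids = map(eachindex(medoids)) do c
            members = findall(==(c), assign)
            members[argmin(vec(sum(D[members, members]; dims = 1)))]
        end
        newMedoids == medoids && break   # converged: medoids stopped changing
        medoids = newMedoids
    end
    return medoids, assign
end

# Toy 2-D data with two obvious groups; both initial medoids are in the first group
X = [0.0 0.0; 0.2 0.1; 0.1 0.3; 5.0 5.0; 5.2 4.9; 4.8 5.1]
medoids, assign = kmedoidsSketch(X, [1, 2])
```

On this toy data the sketch converges after three iterations to medoids [2, 4] and assignments [1, 1, 1, 2, 2, 2].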
3.3 Tangent: hyperparameters
In K-medoids, K is a hyperparameter: a parameter that the model does not learn during training. Rather, this type of parameter is manually set by the user when setting up the model, before training.
Given the study’s goal of predicting whether or not patients will respond to the treatment, the number of clusters K is set to 2 here. However, K is often not known beforehand in clustering problems.
In fact, a common use of clustering is to determine K itself, that is, to reveal whether the data points are organized in groups, which are often hard to recognize through manual data analysis, especially in high-dimensional spaces.
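One common way to compare candidate values of K, not used elsewhere in this tutorial, is the silhouette score: for each point, it compares the mean distance to its own cluster (a) with the mean distance to the nearest other cluster (b). Below is a minimal sketch with a hypothetical helper meanSilhouette on toy data; higher average values indicate better-separated clusters, so one could compute it for several values of K and keep the best.

```julia
using LinearAlgebra, Statistics

# Hypothetical helper: mean silhouette width of a clustering (rows of X are points).
# For point i: a = mean distance to the rest of its own cluster,
# b = smallest mean distance to any other cluster, s = (b - a) / max(a, b).
function meanSilhouette(X::AbstractMatrix, assign::Vector{Int})
    n = size(X, 1)
    D = [norm(X[i, :] .- X[j, :]) for i in 1:n, j in 1:n]
    s = zeros(n)
    for i in 1:n
        own = [j for j in 1:n if assign[j] == assign[i] && j != i]
        isempty(own) && continue                 # singleton cluster: s = 0 by convention
        a = mean(D[i, own])
        b = minimum(mean(D[i, assign .== c]) for c in unique(assign) if c != assign[i])
        s[i] = (b - a) / max(a, b)
    end
    return mean(s)
end

X = [0.0 0.0; 0.2 0.1; 0.1 0.3; 5.0 5.0; 5.2 4.9; 4.8 5.1]   # toy data
good = meanSilhouette(X, [1, 1, 1, 2, 2, 2])   # the two visible groups
poor = meanSilhouette(X, [1, 2, 1, 2, 1, 2])   # an arbitrary split
```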
To determine hyperparameters, optimization techniques can be applied. For this, the dataset is usually split into training and validation sets, and the optimization tends to follow the general procedure in Figure 4. The Tuning Models page in MLJ.jl’s documentation includes examples of strategies.
4 Supervised learning
One of the most popular supervised learning algorithms is the decision tree. It works by repeatedly splitting the dataset, one feature at a time, by establishing thresholds on the features.
For example, a node may split the patients into those younger than 60 years and the rest. Among the younger patients, another node may separate those with a cholesterol level below 4 mmol/L from the rest, and so on. Figure 5 illustrates such a decision tree.
Simplifying the problem to only these 2 features, Figure 6 contains a scatter plot showing the splits of the data and the resulting classifications of patients. The splitting continues until a stopping criterion is met, for example, when a node contains too few samples, when the samples in a node (nearly) all belong to the same class, or when a maximum depth is reached.
GLMakie.activate!() # Plotting backend
# Indices of patients in each category
lowYoung = inputs.Age .< 60 .&& inputs.Cholesterol .< 4
highYoung = inputs.Age .< 60 .&& inputs.Cholesterol .>= 4
older = inputs.Age .>= 60
fig = Figure(; fontsize = 18) # Figure to draw in
ax = Axis(fig[1, 1])
vlines!(ax, 60; color = :brown) # Age split
age = inputs[:, :Age]
# Function for relative position in x axis
ageRel(x) = (x - minimum(age)) / (maximum(age) - minimum(age))
# Cholesterol split
hlines!(ax, 4; xmin = ageRel(38), xmax = ageRel(60.4), color = :blue)
# Scatter plots
plot1 = scatter!( # Non-responders on 1st level
    ax,
    inputs[older, "Age"],
    inputs[older, "Cholesterol"];
    color = :green,
    marker = :diamond,
    markersize = 13,
)
plot2 = scatter!( # Non-responders on 2nd level
    ax,
    inputs[highYoung, "Age"],
    inputs[highYoung, "Cholesterol"];
    color = :purple,
    marker = :xcross,
    markersize = 13,
)
plot3 = scatter!( # Responders
    ax,
    inputs[lowYoung, "Age"],
    inputs[lowYoung, "Cholesterol"];
    color = :red,
    marker = :pentagon,
    markersize = 13,
)
Legend( # Labels for splits
    fig[1, 2],
    [plot1, plot2, plot3],
    ["NR 1st level", "NR 2nd level", "Responders"],
)
# Save PNG
save("./decisionTreeScatter.png", fig)
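How a single split is chosen can be sketched as follows. This assumes the Gini impurity criterion, one common choice for classification trees; specific packages may use other criteria. Every observed value of a feature is tried as a threshold, and the one giving the purest child nodes wins. The helpers and the toy age data are hypothetical.

```julia
# Hypothetical helper: Gini impurity of a set of binary (0/1) labels
gini(y) = isempty(y) ? 0.0 : 2 * (sum(y) / length(y)) * (1 - sum(y) / length(y))

# Hypothetical helper: try every observed value as a threshold ("x < t" goes left)
# and keep the one with the lowest size-weighted impurity of the two child nodes
function bestSplit(x::Vector{Float64}, y::Vector{Int})
    best = (threshold = NaN, impurity = Inf)
    for t in sort(unique(x))
        left, right = y[x .< t], y[x .>= t]
        imp = (length(left) * gini(left) + length(right) * gini(right)) / length(y)
        if imp < best.impurity
            best = (threshold = t, impurity = imp)
        end
    end
    return best
end

toyAge = [45.0, 52.0, 58.0, 63.0, 70.0, 75.0]   # toy feature values (years)
toyResp = [1, 1, 1, 0, 0, 0]                     # toy labels (1 = responder)
bestAge = bestSplit(toyAge, toyResp)             # (threshold = 63.0, impurity = 0.0)
```

A full tree applies the same search recursively to each child node, over all features, until a stopping criterion is met.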
Decision trees are very popular in ML, since they are interpretable, versatile, and scalable. However, they can be overly sensitive to small or superficial changes in the dataset. To overcome this, they can be used in ensembles, which combine multiple simpler models into a single, more robust learner (Brownlee 2024).
For decision trees, common ensembles are random forests. They contain multiple trees, each trained on a random subset of the data. Each tree also uses only a random subset of the features, to avoid building many similar trees when a few features have high predictive power. Usually, random forests use a majority vote among trees for classification, and average the trees’ outputs for regression.
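A heavily simplified sketch of the random-forest idea follows, with hypothetical helpers fitStump and forestPredict. Each “tree” is a single-split stump trained on a bootstrap sample of the rows (drawn with replacement) and on one randomly chosen feature, and the forest predicts by majority vote. Real random forests grow deeper trees and usually draw a feature subset at every split; this only illustrates the two sources of randomness and the voting.

```julia
using Random, Statistics

# Hypothetical helper: one-split "stump" on a single feature. Candidate thresholds
# are midpoints between consecutive distinct values; returns a prediction function.
function fitStump(x::Vector{Float64}, y::Vector{Int})
    xs = sort(unique(x))
    cands = length(xs) > 1 ? (xs[1:end-1] .+ xs[2:end]) ./ 2 : xs
    bestT, bestErr, bestDir = cands[1], Inf, 0
    for t in cands, dir in (0, 1)
        pred = ifelse.(x .< t, dir, 1 - dir)   # class `dir` below t, the other class above
        err = mean(pred .!= y)
        if err < bestErr
            bestT, bestErr, bestDir = t, err, dir
        end
    end
    return xnew -> ifelse.(xnew .< bestT, bestDir, 1 - bestDir)
end

# Hypothetical helper: a tiny "forest" of stumps. Each tree sees a bootstrap sample
# of rows and one randomly chosen feature; the forest predicts by majority vote.
function forestPredict(X, y, Xnew; ntrees = 25, rng = MersenneTwister(0))
    n, p = size(X)
    votes = zeros(Int, size(Xnew, 1))
    for _ in 1:ntrees
        rows = rand(rng, 1:n, n)               # bootstrap: n rows drawn with replacement
        feat = rand(rng, 1:p)                  # random feature subset (of size 1 here)
        stump = fitStump(X[rows, feat], y[rows])
        votes .+= stump(Xnew[:, feat])         # each tree casts one vote per point
    end
    return Int.(votes .> ntrees / 2)           # class 1 if most trees voted for it
end

# Toy data in which both features separate the classes
X = [1.0 5.0; 2.0 6.0; 3.0 7.0; 4.0 8.0; 11.0 15.0; 12.0 16.0; 13.0 17.0; 14.0 18.0]
y = [0, 0, 0, 0, 1, 1, 1, 1]
pred = forestPredict(X, y, X)
```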
4.1 Gradient boosted trees
Another type of ensemble is called boosting (Winston 2014): training the weak learners sequentially, each one focusing on cases that posed a challenge to the previous model.
In gradient boosting (Starmer 2019) specifically, the first model \(M_1\) is trained to map features \(X\) to labels \(Y\), so \(M_1: X \rightarrow Y\) as usual. Then, the second model \(M_2\) is trained to map the same features to the residuals of the first one, \(r_1 = Y - M_1(X)\), so \(M_2: X \rightarrow r_1\). A third model does the same for the residuals of the first two combined, \(r_2 = Y - (M_1(X) + M_2(X))\), and so on. Lastly, the final prediction made by the ensemble is the sum of all these intermediate outputs (in practice, each correction is usually scaled by a learning rate, and for classification the residuals are replaced by gradients of a loss function).
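The residual-fitting loop can be sketched for the simplest case, regression with squared error, using a hypothetical least-squares stump (fitRegStump) as the weak learner on toy data. Each round fits a new stump from the features to the current residuals and adds a shrunken copy of its output to the ensemble, so the training error decreases round after round. For classification, gradient-boosting libraries generally fit gradients of a classification loss (such as log-loss) rather than plain residuals.

```julia
using Statistics

# Hypothetical helper: least-squares regression stump on one feature. It splits at the
# midpoint that minimizes the sum of squared errors and predicts each side's mean.
function fitRegStump(x::Vector{Float64}, r::Vector{Float64})
    xs = sort(unique(x))
    best = (sse = Inf, t = xs[1], left = mean(r), right = mean(r))
    for t in (xs[1:end-1] .+ xs[2:end]) ./ 2
        L, R = r[x .< t], r[x .>= t]
        sse = sum(abs2, L .- mean(L)) + sum(abs2, R .- mean(R))
        if sse < best.sse
            best = (sse = sse, t = t, left = mean(L), right = mean(R))
        end
    end
    return xnew -> ifelse.(xnew .< best.t, best.left, best.right)
end

x = collect(1.0:10.0)
y = [1.0, 1.2, 0.9, 1.1, 3.0, 3.2, 2.9, 5.0, 5.1, 4.9]   # toy step-shaped target

η = 0.5                          # learning rate: each correction is shrunk by this factor
F = fill(mean(y), length(y))     # initial model: predict the mean everywhere
losses = Float64[]
for m in 1:20
    r = y .- F                   # residuals of the current ensemble
    h = fitRegStump(x, r)        # new weak learner maps the features x to the residuals r
    F .+= η .* h(x)              # add its (shrunk) correction to the ensemble
    push!(losses, mean(abs2, y .- F))
end
```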
Such a procedure amounts to stacking consecutive corrections onto the initial prediction. Below, this ensemble is used as the base model of MLJ.jl’s TunedModel utility, which provides workflows for hyperparameter optimization.
GBoostModel = EvoTreeClassifier() # Base model
# Target hyperparameters and their ranges
paramRanges = [
    range(GBoostModel, :max_depth, values = 3:6),
    range(GBoostModel, :nrounds, values = 200:10:300),
]
# Model wrapper for optimization
selfTuningGBoost = TunedModel(
    model = GBoostModel,
    resampling = CV(),
    tuning = Grid(),
    range = paramRanges,
    measure = Accuracy(),
    train_best = false,
)
machGBoost = machine(selfTuningGBoost, lowInputs, outputs)
MLJ.fit!(machGBoost, verbosity = 0)
# Evaluate model with best hyperparameters
bestGBoost = fitted_params(machGBoost).best_model
MLJ.evaluate(bestGBoost, lowInputs, outputs; measure = Accuracy(), verbosity = 0)
PerformanceEvaluation object with these fields:
model, measure, operation,
measurement, per_fold, per_observation,
fitted_params_per_fold, report_per_fold,
train_test_rows, resampling, repeats
Extract:
┌────────────┬──────────────┬─────────────┐
│ measure │ operation │ measurement │
├────────────┼──────────────┼─────────────┤
│ Accuracy() │ predict_mode │ 0.742 │
└────────────┴──────────────┴─────────────┘
┌────────────────────────────────────┬─────────┐
│ per_fold │ 1.96*SE │
├────────────────────────────────────┼─────────┤
│ [0.727, 0.727, 0.6, 0.9, 0.7, 0.8] │ 0.0883 │
└────────────────────────────────────┴─────────┘
In the previous code block, evaluate returns a PerformanceEvaluation object, which produced the output shown. The measure used was accuracy, the rate of correct predictions. The ‘per_fold’ field contains the accuracy obtained on each validation fold, that is, for each data split. The ‘1.96*SE’ field is 1.96 times the standard error of these per-fold values, which can be used to define an approximate 95% confidence interval on the validation performance and to compare models. The ‘measurement’ field is the overall accuracy aggregated across the folds. Finally, note how much the validation accuracy varies depending on the data split (from 0.6 to 0.9).
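The summary numbers can be reproduced from the printed per-fold values. In this sketch, the 1.96*SE half-width is computed as 1.96 times the sample standard deviation of the folds divided by √(k − 1). That denominator is an assumption on our part, chosen because it reproduces the 0.0883 shown above (dividing by √k instead would give about 0.081); consult MLJ’s documentation for the exact definition.

```julia
using Statistics

# Per-fold accuracies as printed in the output above (rounded to 3 decimals)
perFold = [0.727, 0.727, 0.6, 0.9, 0.7, 0.8]
k = length(perFold)

meanAcc = mean(perFold)                          # ≈ 0.742
# Assumption: half-width = 1.96 * sample std of the folds / sqrt(k - 1)
halfWidth = 1.96 * std(perFold) / sqrt(k - 1)    # ≈ 0.0883
ci95 = (meanAcc - halfWidth, meanAcc + halfWidth)
```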
Going back to the TunedModel definition, another setting is the resampling strategy, resampling = CV(). As mentioned in Section 3.3, the data are split into training and validation sets; the reason is to detect and avoid overfitting (Google 2024): adjusting the model too tightly to the training data.
It is important to avoid this phenomenon because data include a (hopefully not dominant) random component, which carries no information about the dynamics the data represent. If the model gets to the point of ‘memorizing’ the training data, it will not generalize well, and its performance will drop considerably on new cases.
Hence, it is standard practice in ML to split the data into training and validation sets (and sometimes also a test set), so that one tracks not only the performance improvement during training but also the validation behavior. There should be an initial phase during which both training and validation metrics improve. Afterwards, however, validation performance may start dropping, indicating overfitting. Figure 7 shows the performance trends in the training and validation splits for gradient-boosted ensembles of different sizes.
# Split data
trainIndices, valIndices = partition(shuffle(MersenneTwister(30), 1:nrow(lowInputs)), 0.8)
# Range for number of trees in ensembles
r = 1:3:500
# Store accuracies in both splits for each ensemble size
accs = Dict("train" => zeros(length(r)), "val" => zeros(length(r)))
for (i, nrounds) in enumerate(r)
    # Instantiate and fit current ensemble
    overModel = EvoTreeClassifier(; nrounds)
    overM = machine(overModel, lowInputs[trainIndices, :], outputs[trainIndices])
    MLJ.fit!(overM; verbosity = 0)
    cmOver = MLJ.ConfusionMatrix(; levels = [1, 2])
    # Accuracy of fitted model in training split
    predictionsOver = MLJ.predict_mode(overM, lowInputs[trainIndices, :])
    cmOver = cmOver(MLJ.levelcode.(predictionsOver), MLJ.levelcode.(outputs[trainIndices]))
    accs["train"][i] = Accuracy()(cmOver)
    # Accuracy of fitted model in validation split
    cmOver = MLJ.ConfusionMatrix(; levels = [1, 2])
    predictionsOver = MLJ.predict_mode(overM, lowInputs[valIndices, :])
    cmOver = cmOver(MLJ.levelcode.(predictionsOver), MLJ.levelcode.(outputs[valIndices]))
    accs["val"][i] = Accuracy()(cmOver)
end
# Plotting
fig = Figure(; fontsize = 18)
ax = Axis(fig[1, 1], xlabel = "Number of trees", ylabel = "Accuracy")
lTrain = lines!(r, accs["train"])
lVal = lines!(r, accs["val"])
Legend( # Labels for splits
    fig[1, 2],
    [lTrain, lVal],
    ["Training", "Validation"],
)
save("./overfit.png", fig)
The y axis indicates the accuracy, the proportion of correct predictions; the higher this value, the better. The x axis indicates the number of trees included in the ensemble. Increasing this number expands the model’s capacity (Ian Goodfellow 2016): the diversity of mappings it can represent. If capacity is too low, the model underfits the training data and does not perform as well as it could. If it is too high, the model has enough flexibility to fit even the noise in the training set. Here, overfitting is identified by the validation performance reaching a peak and then decreasing again.
Besides manually splitting the dataset (“holdout validation”), a common strategy in the context of hyperparameter search is cross-validation (CV), which Figure 8 illustrates:
- Split the data into \(k\) ‘folds’ of (approximately) equal size.
- Iterate over the \(k\) folds:
  - Assign the current fold for validation and the remaining \(k - 1\) folds for training.
  - Set the model hyperparameters.
  - Create a new instance of the model.
  - Train the current model instance on the training folds.
  - Validate the current model instance on the validation fold.
  - Log the performance.
- Use a strategy to choose the hyperparameters according to the validation errors.
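The fold-splitting step above can be sketched in plain Julia. The helper `kfold_splits` below is hypothetical (it is not an MLJ function): it shuffles the row indices, deals them into \(k\) folds, and returns one (training, validation) index pair per fold. The size 62 is only an example, matching the number of patients initially enrolled in the study.

```julia
using Random

# Split observation indices 1:n into k folds of (nearly) equal size and return,
# for each fold, the pair (training indices, validation indices).
function kfold_splits(n::Integer, k::Integer; rng = Random.default_rng())
    1 < k <= n || throw(ArgumentError("need 1 < k <= n"))
    perm = randperm(rng, n)              # shuffle indices before splitting
    folds = [perm[i:k:n] for i in 1:k]   # deal shuffled indices round-robin into k folds
    # Each fold is used once for validation; all other indices form the training set
    return [(sort(setdiff(perm, fold)), sort(fold)) for fold in folds]
end

splits = kfold_splits(62, 5; rng = MersenneTwister(1))
for (f, (train, val)) in enumerate(splits)
    println("fold $f: $(length(train)) training rows, $(length(val)) validation rows")
end
```

In a real CV loop, each pair would be used to fit a fresh model on the training rows and score it on the validation rows; MLJ’s `CV` resampling strategy takes care of this internally.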
\(k\)-fold cross-validation is also sometimes used without changing hyperparameters between iterations. This is useful to get a broader perspective on the model’s performance, averaging out the effect of which data points happen to be used for training. This is what MLJ’s `evaluate()` function does.
Furthermore, regarding the number of folds \(k\) (Kohavi 2001), values from 5 to 10 are common. Higher values increase the computational cost of the procedure but reduce the bias of the performance estimate, since each model is trained on a larger share of the data. Lower values leave less data for training in each iteration: this may be adequate for small models, but larger models may overfit.
Going back to the gradient-boosted `TunedModel`: besides `resampling = CV()`, there is also the `tuning = Grid()` setting, which specifies the strategy used to update the hyperparameters across iterations. `Grid` is the simplest option in MLJ.jl: it tests a set of equally spaced values. This strategy quickly becomes inadequate in demanding applications, since the number of combinations grows multiplicatively with each additional hyperparameter, and it is essentially a trial-and-error exploration of a few possibilities rather than an efficient search algorithm (Run:ai 2024).
Those values are specified by the next setting, `range = paramRanges`, which lists the hyperparameters to be optimized and the range of values considered for each. Here, the `max_depth` hyperparameter, which controls the maximum depth allowed for the decision trees underlying the gradient boosting ensemble, was optimized in the range `3:6`. Additionally, `nrounds`, which determines the number of trees, was optimized in the range `200:10:300`.
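To make the grid strategy concrete, the sketch below performs an exhaustive search over the same two ranges, `max_depth` in `3:6` and `nrounds` in `200:10:300`, i.e., 4 × 11 = 44 combinations. It is not MLJ’s `Grid` implementation, only the idea behind it. The scoring function `toy_score` is a hypothetical placeholder: in the real workflow, each combination would be scored by cross-validation, which is where the computational cost lies.

```julia
# Hypothetical stand-in for a cross-validated score (e.g., mean validation accuracy).
# In a real workflow, this would train and validate the model for each configuration.
toy_score(max_depth, nrounds) = -(max_depth - 4)^2 - ((nrounds - 250) / 50)^2

# Exhaustive grid search: evaluate every combination of candidate values
# and keep the best-scoring one.
function grid_search(scorefn, depths, rounds)
    best = (score = -Inf, max_depth = 0, nrounds = 0)
    for (d, n) in Iterators.product(depths, rounds)
        s = scorefn(d, n)
        if s > best.score
            best = (score = s, max_depth = d, nrounds = n)
        end
    end
    return best
end

best = grid_search(toy_score, 3:6, 200:10:300)
println(best)   # (score = 0.0, max_depth = 4, nrounds = 250)
```

Adding a third hyperparameter with, say, 10 candidate values would multiply the 44 evaluations by 10, which illustrates why grid search scales poorly.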
After fitting the `TunedModel`, a model with the optimal hyperparameters can be extracted and used. Below, such a model is fitted to the training split of manually partitioned data. Later on, it will be used to make predictions on the validation set.
```julia
# Manual data split (partition does not shuffle rows by default)
trainIndices, valIndices = partition(1:nrow(lowInputs), 0.8)
# Bind model to training split
machBest = machine(bestGBoost, lowInputs[trainIndices, :], outputs[trainIndices])
MLJ.fit!(machBest, verbosity = 0);
```
4.2 Model evaluation
In ML, there are usually numerous options for every step of a project, from data preparation to performance evaluation. For example, many metrics exist for binary classification, the task at hand.
Among them, a handful of useful metrics derive from the confusion matrix. After making predictions with the model, this matrix is built by counting how many cases fall into each of the following categories:
- True positive (TP): a positive case (responder) that was correctly classified as positive.
- True negative (TN): a negative case (non-responder) that was correctly classified as negative.
- False positive (FP, type I error): a negative case that was misclassified as positive.
- False negative (FN, type II error): a positive case that was misclassified as negative.
| Predicted / Ground truth | 0 | 1 |
|---|---|---|
| 0 | 4 | 2 |
| 1 | 1 | 5 |
From these basic definitions, the following binary classification metrics can be calculated to prioritize different aspects of performance:
- Precision, or positive predictive value (PPV): fraction of positive classifications that are correct.
\[\frac{TP}{TP + FP}\]
- Recall, sensitivity, or true positive rate (TPR): fraction of actual positive cases that are correctly classified.
\[\frac{TP}{TP + FN}\]
- F1 score: harmonic mean of precision and recall. It is high only if both precision and recall are high.
\[2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}\]
- Accuracy: fraction of classifications that are correct.
\[\frac{TP + TN}{TP + TN + FP + FN}\]
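As a worked example of these formulas, the snippet below reads the counts from the example confusion matrix above, assuming, as its labels indicate, that rows are predicted classes, columns are ground-truth classes, and class 1 (responder) is the positive class. These values illustrate the formulas only; they are not the validation results reported further below.

```julia
# Counts read from the example confusion matrix
# (rows = predicted class, columns = ground truth; 1 = responder = positive class)
TN, FN = 4, 2   # predicted 0: truly 0, truly 1
FP, TP = 1, 5   # predicted 1: truly 0, truly 1

precision_ = TP / (TP + FP)                              # 5/6
recall_ = TP / (TP + FN)                                 # 5/7
f1 = 2 * precision_ * recall_ / (precision_ + recall_)   # harmonic mean = 10/13
accuracy_ = (TP + TN) / (TP + TN + FP + FN)              # 9/12

println("precision = ", round(precision_; digits = 3))  # 0.833
println("recall    = ", round(recall_; digits = 3))     # 0.714
println("F1        = ", round(f1; digits = 3))          # 0.769
println("accuracy  = ", round(accuracy_; digits = 3))   # 0.75
```

The trailing underscores only avoid clashing with Julia’s built-in `precision` function. Note that F1 can also be written directly as \(2TP / (2TP + FP + FN)\), which gives the same 10/13 here.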
```julia
predictionsGBoost = predict_mode(machBest, lowInputs[valIndices, :])
# Confusion matrix [TN FN; FP TP]
cm = MLJ.ConfusionMatrix(; levels = [1, 2])
cmGBoost = cm(MLJ.levelcode.(predictionsGBoost), MLJ.levelcode.(outputs[valIndices]))
# Metrics
prec = ppv(cmGBoost)
rec = recall(cmGBoost)
f = FScore()(cmGBoost)
acc = Accuracy()(cmGBoost);
```
The following table presents the validation metrics associated with the previous model:
| Metric | Value |
|---|---|
| Precision | 0.75 |
| Recall | 0.86 |
| F1 | 0.80 |
| Accuracy | 0.75 |
4.3 K-nearest neighbors pipeline
The next model to be used is K-nearest neighbors (elastic 2024). It is a supervised approach that, given a new input, finds the K training data points closest to it according to a distance metric. In regression problems, it then averages the labels of these neighbors; in classification, it uses the labels of these neighbors in a voting system to determine the output. The parameter K defines how many neighboring points are considered, and its default value is 5. This model will be included in an MLJ pipeline, a functionality that enables very concise definitions of entire ML workflows. The code below recreates the procedures done so far, but now in a pipeline. First, the coercion and one-hot encoding steps are defined. Then, these steps are included in a pipeline alongside imputation, standardization, PCA, and an instance of a K-nearest neighbors model.
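To show the mechanics described above, here is a minimal from-scratch KNN classifier in plain Julia. The function `knn_predict` is hypothetical and is not the `KNNClassifier` from NearestNeighborModels.jl used in the pipeline (which returns class probabilities rather than a single label). This sketch assumes Euclidean distance and a simple majority vote, breaking ties in favor of the label held by the closest neighbor; the toy data are made up.

```julia
using LinearAlgebra: norm

# Minimal K-nearest neighbors classifier (Euclidean distance + majority vote).
# X_train has one observation per row; x is a single new observation.
function knn_predict(X_train::AbstractMatrix, y_train::AbstractVector, x::AbstractVector; K::Integer = 5)
    # Distance from x to every training point
    dists = [norm(X_train[i, :] .- x) for i in axes(X_train, 1)]
    # Indices of the K closest training points, ordered from closest to farthest
    nearest = partialsortperm(dists, 1:K)
    labels = y_train[nearest]
    # Count the votes for each label among the neighbors
    counts = Dict(l => count(==(l), labels) for l in unique(labels))
    maxvotes = maximum(values(counts))
    # Tie-break: among the most-voted labels, return the one held by the closest neighbor
    return first(l for l in labels if counts[l] == maxvotes)
end

# Toy data: two well-separated groups in 2D
X = [0.0 0.0; 0.0 1.0; 1.0 0.0; 5.0 5.0; 5.0 6.0; 6.0 5.0]
y = [0, 0, 0, 1, 1, 1]
println(knn_predict(X, y, [0.5, 0.5]; K = 3))   # 0
println(knn_predict(X, y, [5.5, 5.5]; K = 3))   # 1
```

Because distances drive the vote, features on larger scales dominate the result, which is one reason the pipeline standardizes the inputs before the KNN step.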
```julia
# One-hot encode "Rutherford" and coerce other columns to 'continuous'
ordFacCoerce = inp -> coerce(inp, Symbol("Rutherford") => OrderedFactor)
contCoerce = ContinuousEncoder(; one_hot_ordered_factors = true)
# Pipeline model
modelPipe =
    FillImputer |>
    ordFacCoerce |>
    contCoerce |>
    Standardizer |>
    PCA(; maxoutdim = 10) |>
    knn(; K = 8)
machPipe = machine(modelPipe, df[:, colIDs], outputs)
fit!(machPipe, verbosity = 0)
# Shuffle rows to avoid one-hot encode problem with
# category imbalance in baseline "Rutherford".
shuffleRows = shuffle(MersenneTwister(9), 1:nrow(df))
MLJ.evaluate(
    modelPipe,
    df[shuffleRows, colIDs],
    outputs[shuffleRows];
    measure = [Accuracy(), cross_entropy],
    resampling = CV(; nfolds = 5),
    verbosity = 0,
)
```
```
PerformanceEvaluation object with these fields:
  model, measure, operation,
  measurement, per_fold, per_observation,
  fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌───┬──────────────────────┬──────────────┬─────────────┐
│   │ measure              │ operation    │ measurement │
├───┼──────────────────────┼──────────────┼─────────────┤
│ A │ Accuracy()           │ predict_mode │ 0.645       │
│ B │ LogLoss(             │ predict      │ 0.609       │
│   │   tol = 2.22045e-16) │              │             │
└───┴──────────────────────┴──────────────┴─────────────┘
┌───┬─────────────────────────────────────┬─────────┐
│   │ per_fold                            │ 1.96*SE │
├───┼─────────────────────────────────────┼─────────┤
│ A │ [0.846, 0.385, 0.917, 0.75, 0.333]  │ 0.264   │
│ B │ [0.528, 0.875, 0.424, 0.456, 0.746] │ 0.192   │
└───┴─────────────────────────────────────┴─────────┘
```
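In Julia, the `A |> B` operator is called piping: `x |> f` is equivalent to `f(x)`, so the output of the expression on the left becomes the input of the function on the right, making the flow of information explicit. MLJ extends this operator to models, so that `model1 |> model2` builds a pipeline in which the output of `model1` is passed as input to `model2`.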
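The plain-Julia behavior of `|>` can be seen in a small sketch that chains two preprocessing steps. The functions `impute_mean` and `standardize` are hypothetical, simplified stand-ins for the roles that `FillImputer` and `Standardizer` play in the pipeline; they are not the MLJ models themselves.

```julia
using Statistics

# Hypothetical stand-in for FillImputer: replace missing values with the mean of the observed ones
impute_mean(x) = (m = mean(skipmissing(x)); [ismissing(v) ? m : float(v) for v in x])

# Hypothetical stand-in for Standardizer: rescale to zero mean and unit (sample) standard deviation
standardize(x) = (x .- mean(x)) ./ std(x)

raw = [1.0, missing, 3.0, 5.0]
z = raw |> impute_mean |> standardize   # equivalent to standardize(impute_mean(raw))
println(round.(z; digits = 3))          # [-1.225, 0.0, 0.0, 1.225]
```

Read left to right, the chain mirrors the order in which the data flows through the steps, which is why the same notation reads naturally for MLJ pipelines.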
In the `evaluate` call in the last code block, the measure `cross_entropy` was provided as well; in MLJ it is an alias of the log loss, which is why the output lists `LogLoss`. It refers to binary cross-entropy (Godoy 2018), a measure of dissimilarity between two distributions: here, the labels observed in the data and the probabilities predicted by the model. Ideally, a fitted model’s predicted probabilities match the distribution of the observed labels. Therefore, binary cross-entropy can be used as an objective to be minimized when training binary classification models, which is equivalent to maximizing the log-likelihood (Draelos 2019). In the following formula, \(N\) is the number of observations being compared, \(y_i \in \{0, 1\}\) is the observed label of observation \(i\), and \(p(y_i)\) is the probability predicted by the model that \(y_i = 1\).
\[CE = -\frac{1}{N} \sum_{i = 1}^{N} \left[ y_i \cdot \log p(y_i) + (1 - y_i) \cdot \log\left(1 - p(y_i)\right) \right]\]
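The formula translates almost literally into Julia. The function `binary_cross_entropy` below is a hypothetical helper, not MLJ’s implementation; it clamps probabilities to \([\epsilon, 1 - \epsilon]\) with \(\epsilon\) = `eps()` so that the logarithm never receives 0, similar in spirit to the `tol` value visible in the `LogLoss` output above.

```julia
# Binary cross-entropy: y holds observed labels (0 or 1), p holds predicted probabilities of label 1.
# Probabilities are clamped away from 0 and 1 so that log never receives 0.
function binary_cross_entropy(y::AbstractVector{<:Real}, p::AbstractVector{<:Real}; tol = eps())
    length(y) == length(p) || throw(DimensionMismatch("y and p must have the same length"))
    q = clamp.(p, tol, 1 - tol)
    return -sum(@. y * log(q) + (1 - y) * log(1 - q)) / length(y)
end

println(round(binary_cross_entropy([1, 0], [0.9, 0.1]); digits = 3))   # 0.105 (confident and correct)
println(round(binary_cross_entropy([1, 0], [0.1, 0.9]); digits = 3))   # 2.303 (confident and wrong)
```

For each observation only one of the two terms is active, so the loss is the average of \(-\log\) of the probability assigned to the observed class: confident wrong predictions are penalized heavily.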
The model’s “machine” can then be saved to disk with `MLJ.save("pipeModel.jld2", machPipe)` and loaded later, for example when starting a new session, continuing the project, or building a model catalogue. Combined with pipelines, this is especially convenient: when a new file with additional data arrives, the model can be loaded with `machLoaded = machine("./pipeModel.jld2")` and the entire workflow discussed so far can be reapplied with `pipePredictions = MLJ.predict(machLoaded, df[:, colIDs])`, directly on raw data that still has missing values and hasn’t been preprocessed.
Furthermore, this model outputs, for each case, a probability distribution over the two possible classes; here, this is the predicted chance of the patient responding to the treatment. Therefore, the model’s performance can also be assessed with the area under the receiver operating characteristic curve (AUC).
In short, to turn the predicted chance that a patient will respond to the treatment into a class, a threshold (e.g., 60% or 70%) must be established, above which the patient is classified as a responder. However, the threshold level affects the counts in the confusion matrix (Section 4.2), so the best value for a given context must be chosen.
The receiver operating characteristic (ROC) curve is commonly used for this. It plots the true positive rate (TPR, i.e., recall or sensitivity) on the y-axis against the false positive rate (FPR), \(\frac{FP}{FP + TN}\), on the x-axis. For each threshold level, the model’s predictions yield an (FPR, TPR) pair, i.e., a point in the plot; together, these points form the curve.
A model that guesses the outcome at random (equivalent to a coin flip) is represented by the 45° diagonal line. Ideally, the model’s curve not only lies above this diagonal but also approaches the top-left corner, where the TPR is 1 and the FPR is 0. This can be measured by the area under the curve (AUC), which then approaches 1, whereas random guessing yields an AUC of 0.5.
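The construction of the ROC curve and its AUC can be sketched from scratch. The functions `roc_points` and `auc` below are hypothetical helpers that show what an ROC routine such as MLJ’s `roc_curve` computes conceptually; they are not MLJ’s implementation, and details such as the exact threshold set and tie handling may differ. The six scores and labels are made up.

```julia
# Sweep decision thresholds over predicted probabilities and collect (FPR, TPR) pairs.
# scores: predicted probability of the positive class; labels: true classes (1 = positive, 0 = negative).
function roc_points(scores::AbstractVector{<:Real}, labels::AbstractVector{<:Integer})
    P = count(==(1), labels)                                # actual positives
    N = count(==(0), labels)                                # actual negatives
    thresholds = [Inf; sort(unique(scores); rev = true)]    # Inf gives the (0, 0) starting point
    fpr, tpr = Float64[], Float64[]
    for t in thresholds
        predicted = scores .>= t                            # positive if probability ≥ threshold
        push!(tpr, count(predicted .& (labels .== 1)) / P)  # TP / (TP + FN)
        push!(fpr, count(predicted .& (labels .== 0)) / N)  # FP / (FP + TN)
    end
    return fpr, tpr
end

# Area under the ROC curve with the trapezoidal rule
auc(fpr, tpr) = sum((fpr[k] - fpr[k-1]) * (tpr[k] + tpr[k-1]) / 2 for k in 2:length(fpr))

# Hypothetical predictions for six patients (1 = responder)
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3]
labels = [1, 1, 0, 1, 0, 0]
fpr, tpr = roc_points(scores, labels)
println(round(auc(fpr, tpr); digits = 3))   # 0.889
```

Without tied scores, the AUC equals the fraction of (responder, non-responder) pairs in which the responder receives the higher score: here 8 of the 9 pairs, i.e., 8/9.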
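Finally, the ROC curve of the previous pipeline model is shown in Figure 9. The AUC is 0.83, and the threshold levels used were `[0.2, 0.4, 0.6, 0.8, 1.0]`. Note that `pipePredictions` was computed on the same data used to fit the pipeline, so this AUC is an in-sample estimate and likely optimistic.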
```julia
fig = Figure(; fontsize = 18) # Create figure to draw in
axROC = Axis(
    fig[1, 1];
    title = "Receiver operating characteristic (ROC) curve",
    xlabel = "False positive rate",
    ylabel = "True positive rate",
    xticks = 0:0.2:1,
    yticks = 0:0.2:1,
)
# MLJ function for ROC plot data
fpr, tpr, thresholds = roc_curve(pipePredictions, outputs)
l45 = lines!(axROC, [0, 1], [0, 1]; linestyle = :dash) # 45° line
lModel = lines!(axROC, fpr, tpr) # Model
Legend(fig[1, 2], [l45, lModel], ["Random guess", "kNN"])
save("./roc.png", fig) # Save PNG
```