Machine learning fundamentals

Authors

Lucas Pereira

Mohamed Tarek

Anthony Blaom

1 Introduction

As part of the first unit of the “AI for drug development” course, this tutorial makes didactic use of real data and a real study to introduce machine learning (ML) fundamentals; it does not aim to produce robust clinical tools. The discussion is aimed at professionals and academics from healthcare, especially pharmacometrics.

Furthermore, the workflow we present here assumes no prior knowledge about AI. The focus of each section is as follows:

  1. We begin by summarizing the context of the study below.
  2. In Section 2, common data-processing steps are presented, such as imputation, one-hot encoding, and standardization.
  3. In Section 3, the distinction between supervised and unsupervised learning is explained. There, some unsupervised models are discussed as well, focusing on dimensionality reduction (through principal component analysis) and clustering (including K-Medoids and K-Means).
  4. In Section 4, we present decision trees, ensembles, random forest, and gradient-boosted trees. Also in that section, options for model evaluation in binary classification are detailed. Lastly, K-nearest neighbors is explained, and references for further study are listed.

In the case study analyzed, the goal is to determine the main predictors of response to a treatment against critical limb ischemia (CLI), an advanced-stage disease characterized by the blockage of arteries in a limb. The authors of “Characteristics of responders to autologous bone marrow cell therapy for no-option critical limb ischemia” (Madaric et al. 2016) made their dataset available as a spreadsheet containing multiple features about the 62 patients who initially enrolled in the study. The endpoints of interest in the study are:

  1. Primary endpoints:
    1. Responder: surviving patients with limb salvage and wound healing.
    2. Non-responders: patients requiring major limb amputation or those with no signs of wound healing.
  2. Secondary endpoints:
    1. Mortality.
    2. Amputation-free survival.
    3. Major limb amputation.
    4. Change in tcpO2.
    5. Rutherford category.
    6. Pain scale after BMC transplantation.

From the point of view of ML, this leads to a classification problem: there is a finite number of discrete possible outcomes. More specifically, the problem is one of binary classification: either a patient responded to the treatment or they did not. In contrast, a regression problem has a continuous outcome, such as forecasting energy demand or traffic congestion. Other high-level groupings of ML algorithms can be found in (Brownlee 2023).

It’s important to point out that a dataset with 62 data points is extremely small for ML applications; in general, even datasets with hundreds of points are considered small. Still, the main concepts can be illustrated, although the specific performance values obtained should be taken as didactic rather than definitive.

First, some Julia packages will be necessary. The MLJ.jl package will be used throughout the tutorial, but the ML concepts and techniques discussed here are tool-agnostic.

# Setup environment (only required once).
import Pkg
Pkg.add(
    [
        Pkg.PackageSpec(name = "KernelFunctions", version = v"0.10.64"),
        Pkg.PackageSpec(name = "CSV", version = v"0.10.15"),
        Pkg.PackageSpec(name = "PharmaDatasets", version = v"0.11.1"),
        Pkg.PackageSpec(name = "DataFrames", version = v"1.7.0"),
        Pkg.PackageSpec(name = "Distances", version = v"0.10.12"),
        Pkg.PackageSpec(name = "MLJ", version = v"0.20.7"),
        Pkg.PackageSpec(name = "Clustering", version = v"0.15.8"),
        Pkg.PackageSpec(name = "MultivariateStats", version = v"0.10.3"),
        Pkg.PackageSpec(name = "NearestNeighborModels", version = v"0.2.3"),
        Pkg.PackageSpec(name = "MLJClusteringInterface", version = v"0.1.13"),
        Pkg.PackageSpec(name = "MLJModels", version = v"0.17.6"),
        Pkg.PackageSpec(name = "EvoTrees", version = v"0.16.9"),
        Pkg.PackageSpec(name = "MLJMultivariateStatsInterface", version = v"0.5.3"),
    ]
)
# Bring packages into scope
using KernelFunctions, PharmaDatasets, Statistics
using MLJ, LinearAlgebra, CSV, DataFrames
using Random, Distances, Clustering, MultivariateStats
Random.seed!(378); # Set random seed for reproducibility
# Load models
KMedoids = @load KMedoids pkg = Clustering
EvoTreeClassifier = @load EvoTreeClassifier pkg = EvoTrees
knn = @load KNNClassifier pkg = NearestNeighborModels
KernelPCA = @load KernelPCA pkg = MultivariateStats
PCA = @load PCA pkg = MultivariateStats

2 Data processing

2.1 Identifying relevant input features and response variable

ML projects tend to follow very similar steps (Mehreen 2022). To start the analysis, the file is read and some preprocessing is done. This includes converting occurrences of empty strings ("") into the value missing. In the data overview below, the nmissing column indicates the number of missing values for that dataset column. We’ll turn to these missing values shortly.

# Read spreadsheet data
df = CSV.read(dataset("CLIstemCell", String), DataFrame; missingstring = ["", "#NA!"])
show(df[:, Not(7)]; allcols = true)
62×46 DataFrame
 Row │ PatientID  HealedAmputFreeSurvival  Survival  LimbSalvage  AmputFreeSurvival  HealedLimbSalvage  tcpO2 at 6M  tcpO2 at 12M  Rutherford  Rutherford at 6M  Rutherford at 12M  Pain     Pain at 6M  Pain at 12M  Sex      Age      ArterialHypertension  DiabetesMellitus  Hyperlipidemia  Cholesterol  Obesity  Smoking  Buerger  Creatinine  CD34     MNC      Leu      CRP      FBG         AtorvaDosage  ABI      ABI at 6M   ABI at 12M  QoL Base  QoL at 6M  QoL at 12M  Wound size base  Wound size at 6M  Wound size at 12M  sICAM-1     sICAM-3     sEselectin  sPselectin  sPECAM-1    VEGF         ApplicationRoute
     │ Int64      Float64                  Float64   Float64?     Float64            Float64?           Int64?       Int64?        Float64     Int64?            Int64?             Float64  Int64?      Int64?       Float64  Float64  Float64               Float64           Float64         Float64      Float64  Float64  Float64  Float64     Float64  Float64  Float64  Float64  Float64?    Float64       Float64  Float64?    Float64?    Int64     Int64?     Int64?      Float64          Float64?          Float64?           Float64?    Float64?    Float64?    Float64?    Float64?    Float64?     Float64
─────┼──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │         1                      1.0       1.0          1.0                1.0                1.0           30            27         5.0                 3                  0      0.0           0            0      0.0     47.0                   1.0               1.0             0.0          4.1      0.0      1.0      0.0       101.0    53.76     6.04     7.57      7.1        3.3           10.0     1.09        1.2         1.2         50         70          70              2.0               1.0                0.0  missing     missing     missing     missing     missing     missing                   1.0
   2 │         2                      0.0       1.0          0.0                0.0                0.0      missing       missing         5.0           missing            missing      4.0     missing      missing      0.0     70.0                   1.0               1.0             1.0          5.4      1.0      0.0      0.0       160.0    12.65     3.42     8.0      22.0        2.8            0.0     0.8   missing     missing           50    missing     missing              6.0         missing            missing    missing     missing     missing     missing     missing     missing                   1.0
   3 │         3                      0.0       1.0          1.0                1.0                0.0           28            28         5.0                 5                  3      7.0           0            0      1.0     70.0                   1.0               1.0             0.0          3.9      0.0      0.0      0.0        88.0    18.05     5.64    10.7       3.4        4.11          10.0     1.76        1.11        1.11        35         50          50              3.0               4.0                6.0  missing     missing     missing     missing     missing     missing                   1.0
   4 │         4                      0.0       1.0          0.0                0.0                0.0      missing       missing         5.0           missing            missing      5.0     missing      missing      0.0     73.0                   0.0               1.0             0.0          3.6      0.0      0.0      0.0        91.0     8.08     4.04    18.86    319.0  missing              0.0     1.4   missing     missing           40    missing     missing              8.0         missing            missing    missing     missing     missing     missing     missing     missing                   1.0
   5 │         5                      1.0       1.0          1.0                1.0                1.0           43            27         5.0                 2                  2      0.0           0            0      0.0     64.0                   1.0               1.0             1.0          3.2      0.0      1.0      0.0        92.0    10.53     3.9      8.9      33.2        4.67          10.0     1.3         0.64        0.64        30         70          70              6.0               4.0                0.0  missing     missing     missing     missing     missing     missing                   0.0
   6 │         6                      1.0       1.0          1.0                1.0                1.0           36            29         5.0                 1                  1      3.0           0            0      0.0     56.0                   1.0               1.0             1.0          3.8      0.0      1.0      0.0        73.0    33.95     4.14     6.3       1.8        1.8           20.0     0.39        0.84        0.67        50         80          80              2.0               0.0                0.0     1281.05       89.91      114.88      300.78      244.45      293.901               1.0
   7 │         7                      0.0       1.0          0.0                0.0                0.0           30       missing         5.0                 5                  6      6.0           1      missing      0.0     79.0                   1.0               0.0             0.0          4.7      0.0      0.0      0.0        85.0    25.38     4.15    10.33    105.0        4.11          10.0     0.7         0.71  missing           40         60     missing              8.0              16.0          missing    missing     missing     missing     missing     missing         700.245               0.0
   8 │         8                      1.0       1.0          1.0                1.0                1.0           38            20         5.0                 3                  3      0.0           0            0      0.0     68.0                   1.0               1.0             1.0          5.8      0.0      1.0      0.0       120.0     5.42     4.88     6.81      8.3        4.18          20.0     0.62        0.58        0.58        60    missing          75              6.0               2.5                2.5      675.56      214.99      209.87      842.11      338.51      349.82                1.0
  ⋮  │     ⋮                 ⋮                ⋮           ⋮               ⋮                  ⋮               ⋮            ⋮            ⋮              ⋮                  ⋮             ⋮         ⋮            ⋮          ⋮        ⋮              ⋮                   ⋮                ⋮              ⋮          ⋮        ⋮        ⋮         ⋮          ⋮        ⋮        ⋮        ⋮         ⋮            ⋮           ⋮         ⋮           ⋮          ⋮          ⋮          ⋮              ⋮                ⋮                  ⋮              ⋮           ⋮           ⋮           ⋮           ⋮            ⋮              ⋮
  56 │        56                      1.0       1.0          1.0                1.0                1.0           10             5         5.0                 2                  2      7.0           1            2      0.0     43.0                   0.0               0.0             0.0          4.7      0.0      1.0      1.0        90.0    28.56     4.2      8.34      1.0        2.61          20.0     0.73  missing           0.72        70         70          70              2.0               0.0                0.0      534.47      138.65      145.25       79.24      221.39      140.031               1.0
  57 │        57                      1.0       1.0          1.0                1.0                1.0            4            13         5.0                 5                  3      1.0           1            0      1.0     57.0                   0.0               1.0             1.0          3.4      0.0      0.0      0.0        66.0    40.61     5.72     9.62     25.9        4.65          40.0     1.23        1.2         1.2         50         70          75              8.0               4.0                0.0      895.67      191.96      200.19     1002.39      346.05       56.602               0.0
  58 │        58                      0.0       1.0          0.0                0.0                0.0      missing       missing         5.0                 6                  6      5.0     missing      missing      0.0     71.0                   1.0               1.0             0.0          3.9      0.0      0.0      0.0       101.0    20.56     2.7      6.78      7.2        4.38          20.0     0.56  missing     missing           50    missing     missing              4.0         missing            missing    missing     missing     missing     missing     missing     missing                   0.0
  59 │        59                      1.0       1.0          1.0                1.0                1.0            7             7         5.0                 3                  3      6.0           2            2      0.0     43.0                   1.0               0.0             1.0          4.2      0.0      1.0      0.0        65.0    30.24     4.2      6.7       0.4        2.37          20.0     0.67  missing     missing           40         70          70              4.0               0.0                0.0      469.34      126.55      139.22      499.71      351.21       29.626               1.0
  60 │        60                      1.0       1.0          1.0                1.0                1.0           35            27         5.0                 5                  3      5.0           2            0      0.0     74.0                   1.0               1.0             1.0          4.7      0.0      0.0      0.0        86.0    35.78     5.04     7.69      2.0        3.14          20.0     0.6         0.6         0.61        80         70          85              2.0               1.0                0.0  missing     missing     missing     missing     missing     missing                   1.0
  61 │        61                      0.0       1.0          0.0                0.0                0.0      missing       missing         6.0                 6                  6      6.0     missing      missing      0.0     52.0                   1.0               0.0             1.0          3.9      0.0      1.0      0.0        74.0    38.38     5.73     9.97      2.5        2.59          10.0     0.38  missing     missing           50    missing     missing             22.0         missing            missing    missing     missing     missing     missing     missing     missing                   1.0
  62 │        62                      0.0       1.0          1.0                1.0                0.0           16            51         5.0                 5                  5      5.0           3            0      0.0     72.0                   1.0               1.0             1.0          3.2      0.0      0.0      0.0       104.0    13.15     3.37     8.61      1.3        3.47          80.0     1.29        1.08        1.1         80         60          70              2.0               6.0                6.0  missing     missing     missing     missing     missing     missing                   0.0
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     47 rows omitted

We now shuffle the order in which observations (rows) appear. Having done this, we can create simple splits of the observations (e.g., into train and test sets) and have a better chance that each subset is a fair representative of the whole.

indices = 1:nrow(df)
shuffled = shuffle(MersenneTwister(37), indices)
df = df[shuffled, :];
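Although only the shuffle itself is needed here, a train/test split of the rows could be sketched with MLJ's partition helper as follows (the 80/20 ratio and the seed are our choices; this split is not used in the rest of the tutorial):

```julia
using MLJ, Random

# Sketch: split the 62 row indices into ~80% train / ~20% test.
train, test = partition(1:62, 0.8; shuffle = true, rng = MersenneTwister(37))
# `train` and `test` are disjoint and together cover all 62 rows.
```

Rows would then be selected with df[train, :] and df[test, :].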

Next, we split the data into inputs (a.k.a. features or covariates) and outputs (a.k.a. targets or labels), dropping some input variables that won’t play any role in the current analysis:

# Subset of columns used as inputs
colIDs = vcat([7, 10, 13], 16:32, [35, 38], 41:46)
inputs = df[:, colIDs] # 62x28

# Primary endpoint as binary prediction target
outputs = collect(df[:, "HealedAmputFreeSurvival"]);

Our choice of colIDs includes most of the features but excludes most secondary endpoints. Also excluded were snapshots taken after baseline. Snapshots are features that were measured at multiple moments: right before treatment (baseline), and 6 and 12 months after treatment. They comprise tcpO2, ABI (ankle-brachial index), wound size, and the Rutherford, pain, and quality-of-life scales.

Since the purpose of the ML models developed in this tutorial is to predict whether a new patient suffering from CLI would benefit from the proposed treatment, only the baseline values of snapshot variables will be used. However, one can imagine using such models 6 months after treatment to obtain a more precise prediction of how the patient’s condition will develop over the following 6 months. The inclusion of the snapshots at 6 months is left as an exercise for the reader.

2.2 Coercing data types

Data representing the same “scientific type” is often encoded using a multitude of machine types. For example, a “number of legs” variable might be represented using an Int (as in 0, 2, 4, …), as a Float64 (0.0, 2.0, 4.0, …), or as a String ("0", "2", "4", …; or "zero", "two", "four", …). Algorithms cannot determine the intended scientific type and will process data strictly according to its machine representation and some convention chosen by the implementation (such as “a Float64 always represents a continuous variable”). To mitigate possible ambiguities, the MLJ package does two things:

  • It assigns to each Julia object a “scientific” type, or scitype, such as Continuous, Count, OrderedFactor, and Multiclass (unordered factor). This is distinct from ordinary machine types, like Float64. You can inspect the scitype using the scitype method.

  • It provides a method coerce for forcing data to have the scientific type you intend, rather than the one it has by circumstance.
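The distinction between machine type and scitype can be made concrete with a few quick checks (a sketch; the toy values are ours):

```julia
using MLJ  # re-exports `scitype` and `coerce`

scitype(2)      # Count      (machine type Int64)
scitype(2.0)    # Continuous (machine type Float64)
scitype("two")  # Textual    (machine type String)

# `coerce` changes the scientific interpretation of the data:
scitype(coerce([0.0, 2.0, 4.0], Count))  # AbstractVector{Count}
```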

For example, let’s check the scitype of our binary CLI response variable:

scitype(outputs)
AbstractVector{Continuous} (alias for AbstractArray{Continuous, 1})

which is not how we want our response interpreted. Since the response is a binary variable with a conventional “positive” class (here represented as 1.0), we will coerce it to OrderedFactor. (When there is no preferred “positive” class, as in a person’s gender, we use Multiclass instead.)

# fix scitype of `outputs`:
outputs = coerce(outputs, OrderedFactor)
scitype(outputs)
AbstractVector{OrderedFactor{2}} (alias for AbstractArray{OrderedFactor{2}, 1})

The first class, as listed by the levels function, will be interpreted as the “negative” class. Use levels! to reorder the classes; this is unnecessary here:

levels(outputs) |> println
[0.0, 1.0]

To simultaneously inspect all column scitypes in a table, MLJ provides the schema method:

schema(inputs) |> DataFrame |> println
28×3 DataFrame
 Row │ names                 scitypes                    types                 ⋯
     │ Symbol                Type                        Type                  ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ tcpO2                 Continuous                  Float64               ⋯
   2 │ Rutherford            Continuous                  Float64
   3 │ Pain                  Continuous                  Float64
   4 │ Sex                   Continuous                  Float64
   5 │ Age                   Continuous                  Float64               ⋯
   6 │ ArterialHypertension  Continuous                  Float64
   7 │ DiabetesMellitus      Continuous                  Float64
   8 │ Hyperlipidemia        Continuous                  Float64
  ⋮  │          ⋮                        ⋮                          ⋮          ⋱
  22 │ Wound size base       Continuous                  Float64               ⋯
  23 │ sICAM-1               Union{Missing, Continuous}  Union{Missing, Float6
  24 │ sICAM-3               Union{Missing, Continuous}  Union{Missing, Float6
  25 │ sEselectin            Union{Missing, Continuous}  Union{Missing, Float6
  26 │ sPselectin            Union{Missing, Continuous}  Union{Missing, Float6 ⋯
  27 │ sPECAM-1              Union{Missing, Continuous}  Union{Missing, Float6
  28 │ VEGF                  Union{Missing, Continuous}  Union{Missing, Float6
                                                    1 column and 13 rows omitted

Note that MLJ also views Missing as a scitype. We can change the scitype of the Rutherford input feature (measuring, on the critical limb ischemia Rutherford scale, severity of the disease) like this:

inputs = coerce(inputs, :Rutherford => OrderedFactor)
# check scitypes
scitype(inputs.Rutherford)
AbstractVector{OrderedFactor{2}} (alias for AbstractArray{OrderedFactor{2}, 1})

The QoL Base variable should be Continuous instead of Count, which we fix like this:

inputs = coerce(inputs, Symbol("QoL Base") => Continuous)
# final scitype check:
schema(inputs) |> DataFrame |> println
28×3 DataFrame
 Row │ names                 scitypes                    types                 ⋯
     │ Symbol                Type                        Type                  ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ tcpO2                 Continuous                  Float64               ⋯
   2 │ Rutherford            OrderedFactor{2}            CategoricalValue{Floa
   3 │ Pain                  Continuous                  Float64
   4 │ Sex                   Continuous                  Float64
   5 │ Age                   Continuous                  Float64               ⋯
   6 │ ArterialHypertension  Continuous                  Float64
   7 │ DiabetesMellitus      Continuous                  Float64
   8 │ Hyperlipidemia        Continuous                  Float64
  ⋮  │          ⋮                        ⋮                               ⋮     ⋱
  22 │ Wound size base       Continuous                  Float64               ⋯
  23 │ sICAM-1               Union{Missing, Continuous}  Union{Missing, Float6
  24 │ sICAM-3               Union{Missing, Continuous}  Union{Missing, Float6
  25 │ sEselectin            Union{Missing, Continuous}  Union{Missing, Float6
  26 │ sPselectin            Union{Missing, Continuous}  Union{Missing, Float6 ⋯
  27 │ sPECAM-1              Union{Missing, Continuous}  Union{Missing, Float6
  28 │ VEGF                  Union{Missing, Continuous}  Union{Missing, Float6
                                                    1 column and 13 rows omitted

2.3 One-hot encoding

Some supervised models can handle categorical inputs, i.e., those that assume a finite set of possible discrete values and generally have OrderedFactor or Multiclass scitypes. However, many models only support Continuous inputs. For this reason, we apply one-hot encoding to convert our OrderedFactor variable Rutherford into a new Continuous variable:

# One-hot encode the OrderedFactor column Rutherford
mach = machine(OneHotEncoder(; ordered_factor = true, drop_last = true), inputs)
fit!(mach, verbosity = 0)
inputs = MLJ.transform(mach, inputs) # 62x28
schema(inputs) |> DataFrame |> println
28×3 DataFrame
 Row │ names                 scitypes                    types                 ⋯
     │ Symbol                Type                        Type                  ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ tcpO2                 Continuous                  Float64               ⋯
   2 │ Rutherford__5.0       Continuous                  Float64
   3 │ Pain                  Continuous                  Float64
   4 │ Sex                   Continuous                  Float64
   5 │ Age                   Continuous                  Float64               ⋯
   6 │ ArterialHypertension  Continuous                  Float64
   7 │ DiabetesMellitus      Continuous                  Float64
   8 │ Hyperlipidemia        Continuous                  Float64
  ⋮  │          ⋮                        ⋮                          ⋮          ⋱
  22 │ Wound size base       Continuous                  Float64               ⋯
  23 │ sICAM-1               Union{Missing, Continuous}  Union{Missing, Float6
  24 │ sICAM-3               Union{Missing, Continuous}  Union{Missing, Float6
  25 │ sEselectin            Union{Missing, Continuous}  Union{Missing, Float6
  26 │ sPselectin            Union{Missing, Continuous}  Union{Missing, Float6 ⋯
  27 │ sPECAM-1              Union{Missing, Continuous}  Union{Missing, Float6
  28 │ VEGF                  Union{Missing, Continuous}  Union{Missing, Float6
                                                    1 column and 13 rows omitted

Note that the Rutherford column was replaced by another called Rutherford__5.0. This happened because the keyword argument drop_last was set to true, causing ‘reference encoding’ to be used instead of the standard procedure. In reference encoding, the number of binary variables created is one less than the number of categories; (baseline) Rutherford only had categories 5 and 6, since the study focused on critical cases of the disease.
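The effect of drop_last can be reproduced on a toy two-level column (a sketch; the toy frame is ours):

```julia
using MLJ, DataFrames

# Toy illustration of reference encoding on a two-level column.
toy = coerce(DataFrame(x = [5.0, 6.0, 5.0]), :x => OrderedFactor)
mach = machine(OneHotEncoder(; ordered_factor = true, drop_last = true), toy)
fit!(mach, verbosity = 0)
MLJ.transform(mach, toy)  # a single binary column; the last level (6.0) is dropped
```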

As an aside, let’s see how standard one-hot encoding would work by taking a look at an input feature that we excluded, Rutherford at 6M. This measures, on the critical limb ischemia Rutherford scale, the severity of the disease after 6 months. While the baseline Rutherford measurement included above takes only two possible values, Rutherford at 6M assumes all six possible values on the scale. We’ll first create a mini-dataframe that includes this variable and fills in its missing values with a procedure explained later:

df_small = df[:, [11, 13]] # "Rutherford at 6M" and "Pain" columns (62x2)
df_small = coerce(df_small, Symbol("Rutherford at 6M") => OrderedFactor)
mach = machine(FillImputer(), df_small)
fit!(mach, verbosity=0)
df_small = MLJ.transform(mach, df_small)
first(df_small, 10) |> println
Info: Trying to coerce from `Union{Missing, Int64}` to `ScientificTypesBase.OrderedFactor`.
Coerced to `Union{Missing,ScientificTypesBase.OrderedFactor}` instead.
10×2 DataFrame
 Row │ Rutherford at 6M  Pain
     │ Categorical…      Float64
─────┼───────────────────────────
   1 │ 5                     4.0
   2 │ 2                     0.0
   3 │ 6                     9.0
   4 │ 5                     4.0
   5 │ 5                     1.0
   6 │ 5                     4.0
   7 │ 2                     0.0
   8 │ 5                     5.0
   9 │ 5                     7.0
  10 │ 6                     5.0

Now for the encoding:

mach = machine(OneHotEncoder(; ordered_factor = true), df_small)
fit!(mach, verbosity=0)
df_hot = MLJ.transform(mach, df_small)
show(df_hot; allcols = true)
62×7 DataFrame
 Row │ Rutherford at 6M__1  Rutherford at 6M__2  Rutherford at 6M__3  Rutherford at 6M__4  Rutherford at 6M__5  Rutherford at 6M__6  Pain
     │ Float64              Float64              Float64              Float64              Float64              Float64              Float64
─────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │                 0.0                  0.0                  0.0                  0.0                  1.0                  0.0      4.0
   2 │                 0.0                  1.0                  0.0                  0.0                  0.0                  0.0      0.0
   3 │                 0.0                  0.0                  0.0                  0.0                  0.0                  1.0      9.0
   4 │                 0.0                  0.0                  0.0                  0.0                  1.0                  0.0      4.0
   5 │                 0.0                  0.0                  0.0                  0.0                  1.0                  0.0      1.0
   6 │                 0.0                  0.0                  0.0                  0.0                  1.0                  0.0      4.0
   7 │                 0.0                  1.0                  0.0                  0.0                  0.0                  0.0      0.0
   8 │                 0.0                  0.0                  0.0                  0.0                  1.0                  0.0      5.0
  ⋮  │          ⋮                    ⋮                    ⋮                    ⋮                    ⋮                    ⋮              ⋮
  56 │                 0.0                  0.0                  1.0                  0.0                  0.0                  0.0      0.0
  57 │                 0.0                  0.0                  0.0                  0.0                  0.0                  1.0      6.0
  58 │                 0.0                  0.0                  0.0                  0.0                  1.0                  0.0      6.0
  59 │                 0.0                  0.0                  0.0                  0.0                  1.0                  0.0      7.0
  60 │                 0.0                  0.0                  0.0                  0.0                  0.0                  1.0      6.0
  61 │                 1.0                  0.0                  0.0                  0.0                  0.0                  0.0      3.0
  62 │                 0.0                  0.0                  0.0                  0.0                  1.0                  0.0     10.0
                                                                                                                              47 rows omitted

We see that one-hot encoding turns this categorical column of 6 levels into 6 binary variables: Rutherford at 6M__i, for i from 1 to the number of unique values of the categorical variable. Across these new binary variables, each patient has 0 in all but one; the single 1 indicates the patient’s value in the original categorical column.
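The mechanics can be sketched by hand; the onehot helper below is ours, not part of MLJ:

```julia
# Hand-rolled sketch of one-hot encoding (the `onehot` name is ours,
# not an MLJ function). Each level becomes one binary column.
function onehot(v)
    lvls = sort(unique(v))
    return [x == l ? 1.0 : 0.0 for x in v, l in lvls]  # n × k matrix
end

onehot([5, 2, 6, 5])
# 4×3 Matrix{Float64} (columns for levels 2, 5, 6):
#  0.0  1.0  0.0
#  1.0  0.0  0.0
#  0.0  0.0  1.0
#  0.0  1.0  0.0
```

Each row sums to 1, matching the pattern in the table above.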

2.4 Imputation

The next data-processing step is called imputation: assigning a value to every missing entry. For this, the MLJ package has the FillImputer() utility, which, by default, fills each column with its median or mode, according to the type of data. This is a very simple approach to imputation, and many alternatives exist (Airbyte 2024).

# Impute missing values by median and mode
mach = machine(FillImputer(), inputs) # Attach model to data
MLJ.fit!(mach, verbosity = 0)
imputedInputs = MLJ.transform(mach, inputs) # Perform imputation
show(imputedInputs[:, Not(7)]; allcols = true) # Overview
62×27 DataFrame
 Row │ tcpO2    Rutherford__5.0  Pain     Sex      Age      ArterialHypertension  Hyperlipidemia  Cholesterol  Obesity  Smoking  Buerger  Creatinine  CD34     MNC      Leu      CRP      FBG      AtorvaDosage  ABI      QoL Base  Wound size base  sICAM-1  sICAM-3  sEselectin  sPselectin  sPECAM-1  VEGF
     │ Float64  Float64          Float64  Float64  Float64  Float64               Float64         Float64      Float64  Float64  Float64  Float64     Float64  Float64  Float64  Float64  Float64  Float64       Float64  Float64   Float64          Float64  Float64  Float64     Float64     Float64   Float64
─────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │    26.0              1.0      4.0      0.0     58.0                   0.0             1.0          4.8      0.0      1.0      1.0        73.0    48.12     4.54    11.16      4.3     3.9           20.0     0.69      90.0              4.0   493.31    75.52      171.82     1402.54    336.04   28.795
   2 │    30.0              1.0      0.0      0.0     63.0                   1.0             1.0          6.0      0.0      0.0      0.0        92.0    32.89     4.98    14.51      0.4     2.6           10.0     0.75      50.0              2.0   457.96   206.34      100.59      798.98    393.13  709.564
   3 │    10.0              1.0      9.0      0.0     54.0                   1.0             1.0          6.2      0.0      0.0      0.0        90.0    46.37     5.01    10.67     14.8     4.7           10.0     1.36      50.0              8.0   580.7     42.86       75.17      542.72    293.08  521.304
   4 │     4.0              1.0      4.0      0.0     70.0                   1.0             1.0          5.4      1.0      0.0      0.0       160.0    12.65     3.42     8.0      22.0     2.8            0.0     0.8       50.0              6.0   622.97   105.89      128.04      640.11    379.9    68.239
   5 │    20.0              1.0      1.0      1.0     57.0                   0.0             1.0          3.4      0.0      0.0      0.0        66.0    40.61     5.72     9.62     25.9     4.65          40.0     1.23      50.0              8.0   895.67   191.96      200.19     1002.39    346.05   56.602
   6 │     0.0              1.0      4.0      1.0     42.0                   0.0             0.0          4.2      0.0      1.0      1.0        43.0    35.13     4.95    10.12      3.2     4.51           0.0     0.69      50.0              6.0  1281.05   123.65      128.04      771.58    405.4    34.303
   7 │    21.0              0.0      0.0      0.0     62.0                   1.0             1.0          5.3      0.0      0.0      0.0        93.0    73.73     8.19     8.2       5.7     3.58          10.0     1.05      60.0              4.5   741.64   197.97      235.18     1109.87    516.25  463.522
   8 │    11.0              1.0      5.0      0.0     49.0                   1.0             1.0          4.5      0.0      1.0      0.0        66.0    21.6      5.4     11.8      48.2     5.89           0.0     1.14      50.0              8.0  1593.84    65.39      225.35     1109.87    393.13  418.787
  ⋮  │    ⋮            ⋮            ⋮        ⋮        ⋮              ⋮                  ⋮              ⋮          ⋮        ⋮        ⋮         ⋮          ⋮        ⋮        ⋮        ⋮        ⋮          ⋮           ⋮        ⋮             ⋮            ⋮        ⋮         ⋮           ⋮          ⋮         ⋮
  56 │    11.0              1.0      0.0      0.0     47.0                   1.0             0.0          4.1      0.0      1.0      0.0       101.0    53.76     6.04     7.57      7.1     3.3           10.0     1.09      50.0              2.0   622.97   105.89      128.04      640.11    379.9    68.239
  57 │     1.0              0.0      6.0      0.0     52.0                   1.0             1.0          3.9      0.0      1.0      0.0        74.0    38.38     5.73     9.97      2.5     2.59          10.0     0.38      50.0             22.0   622.97   105.89      128.04      640.11    379.9    68.239
  58 │    20.0              1.0      6.0      0.0     75.0                   0.0             1.0          4.3      0.0      0.0      0.0       118.0    16.62     1.93     2.59     14.5     3.7           20.0     0.8       50.0              4.0   277.87    62.21       63.58      377.19    231.41  519.44
  59 │     1.0              1.0      7.0      0.0     67.0                   1.0             0.0          3.9      1.0      0.0      0.0        90.0    29.38     4.45     6.39      1.0     3.82          20.0     0.59      80.0              7.0   622.97   116.62       55.85      618.44    312.53   17.197
  60 │     7.0              1.0      6.0      1.0     64.0                   1.0             1.0          5.9      0.0      0.0      0.0        66.0    21.6      3.04     4.79     20.7     4.12          20.0     0.0       50.0              6.0   622.97   105.89      128.04      640.11    379.9    68.239
  61 │     4.0              1.0      3.0      0.0     56.0                   1.0             1.0          3.8      0.0      1.0      0.0        73.0    33.95     4.14     6.3       1.8     1.8           20.0     0.39      50.0              2.0  1281.05    89.91      114.88      300.78    244.45  293.901
  62 │     5.0              1.0     10.0      0.0     65.0                   1.0             1.0          6.1      0.0      1.0      0.0        48.0    31.42     5.93     8.18     43.6     4.68          40.0     1.7       40.0              6.0   394.16    66.86      102.66     1147.32    570.13   48.379
                                                                                                                                                                                                                                                                                                  47 rows omitted

Note that, although the nmissing column now contains only zeros, the means and medians changed little from the previous step. This is a good sign, indicating that the imputation did not significantly distort the original data distribution. As an exploration step, the histograms in Figure 1 can be plotted to get more familiar with the data.

Figure 1: Histograms of selected variables. Leu - leukocyte level in peripheral blood.
# Activate plotting backend
GLMakie.activate!()
fig = Figure(; fontsize = 18) # Create figure to draw in
# Axis and plot for each histogram
ax = Axis(fig[1, 1]; title = "Age (y)")
hist!(ax, imputedInputs[:, "Age"])
ax = Axis(fig[1, 2]; title = "Cholesterol (mmol/L)")
hist!(ax, imputedInputs[:, "Cholesterol"])
ax = Axis(fig[2, 1]; title = "Leu (10⁹/L)")
hist!(ax, imputedInputs[:, "Leu"])
ax = Axis(fig[2, 2]; title = "Creatinine (μmol/L)")
hist!(ax, imputedInputs[:, "Creatinine"])
save("./case1_hists.svg", fig) # Save SVG

Now that we have removed missing values and categorical features, there are a large number of MLJ supervised learners available to us, which we can list like this:

# List of available models in MLJ.jl
models(matching(imputedInputs, outputs)) .|> println;
(name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )
(name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )
(name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )
(name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )
(name = BayesianLDA, package_name = MultivariateStats, ... )
(name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )
(name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )
(name = CatBoostClassifier, package_name = CatBoost, ... )
(name = ConstantClassifier, package_name = MLJModels, ... )
(name = DecisionTreeClassifier, package_name = BetaML, ... )
(name = DecisionTreeClassifier, package_name = DecisionTree, ... )
(name = DeterministicConstantClassifier, package_name = MLJModels, ... )
(name = DummyClassifier, package_name = MLJScikitLearnInterface, ... )
(name = EvoTreeClassifier, package_name = EvoTrees, ... )
(name = ExtraTreesClassifier, package_name = MLJScikitLearnInterface, ... )
(name = GaussianNBClassifier, package_name = MLJScikitLearnInterface, ... )
(name = GaussianNBClassifier, package_name = NaiveBayes, ... )
(name = GaussianProcessClassifier, package_name = MLJScikitLearnInterface, ... )
(name = GradientBoostingClassifier, package_name = MLJScikitLearnInterface, ... )
(name = HistGradientBoostingClassifier, package_name = MLJScikitLearnInterface, ... )
(name = KNNClassifier, package_name = NearestNeighborModels, ... )
(name = KNeighborsClassifier, package_name = MLJScikitLearnInterface, ... )
(name = KernelPerceptronClassifier, package_name = BetaML, ... )
(name = LDA, package_name = MultivariateStats, ... )
(name = LGBMClassifier, package_name = LightGBM, ... )
(name = LinearBinaryClassifier, package_name = GLM, ... )
(name = LinearSVC, package_name = LIBSVM, ... )
(name = LogisticCVClassifier, package_name = MLJScikitLearnInterface, ... )
(name = LogisticClassifier, package_name = MLJLinearModels, ... )
(name = LogisticClassifier, package_name = MLJScikitLearnInterface, ... )
(name = MaxnetBinaryClassifier, package_name = Maxnet, ... )
(name = MultinomialClassifier, package_name = MLJLinearModels, ... )
(name = NeuralNetworkBinaryClassifier, package_name = MLJFlux, ... )
(name = NeuralNetworkClassifier, package_name = BetaML, ... )
(name = NeuralNetworkClassifier, package_name = MLJFlux, ... )
(name = NuSVC, package_name = LIBSVM, ... )
(name = PassiveAggressiveClassifier, package_name = MLJScikitLearnInterface, ... )
(name = PegasosClassifier, package_name = BetaML, ... )
(name = PerceptronClassifier, package_name = BetaML, ... )
(name = PerceptronClassifier, package_name = MLJScikitLearnInterface, ... )
(name = ProbabilisticNuSVC, package_name = LIBSVM, ... )
(name = ProbabilisticSGDClassifier, package_name = MLJScikitLearnInterface, ... )
(name = ProbabilisticSVC, package_name = LIBSVM, ... )
(name = RandomForestClassifier, package_name = BetaML, ... )
(name = RandomForestClassifier, package_name = DecisionTree, ... )
(name = RandomForestClassifier, package_name = MLJScikitLearnInterface, ... )
(name = RidgeCVClassifier, package_name = MLJScikitLearnInterface, ... )
(name = RidgeClassifier, package_name = MLJScikitLearnInterface, ... )
(name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )
(name = SVC, package_name = LIBSVM, ... )
(name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )
(name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )
(name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )
(name = StableForestClassifier, package_name = SIRUS, ... )
(name = StableRulesClassifier, package_name = SIRUS, ... )
(name = SubspaceLDA, package_name = MultivariateStats, ... )
(name = XGBoostClassifier, package_name = XGBoost, ... )

2.5 Standardization

Afterwards, each column is standardized, which means subtracting the mean and dividing by the standard deviation.

\[ \bar{x_{i}} = \frac{x_{i} - \mathbb{E}[x]}{\sqrt{Var[x]}} \]

As a result, all inputs have zero mean and unit standard deviation. This alleviates some numerical problems that may show up during training. An alternative is quantile normalization (Bolstad et al. 2003), which can be achieved through MLJ’s UnivariateDiscretizer.

# Standardization
mach = machine(Standardizer(), imputedInputs)
# Columns' mean and std calculated through 'fitting'
MLJ.fit!(mach, verbosity = 0)
# Use established mean and std to transform data
stdInputs = MLJ.transform(mach, imputedInputs)
# Check transformation
stdCols = eachcol(stdInputs)
DataFrame(
    Feature = names(stdInputs),
    Min = minimum.(stdCols),
    Mean = mean.(stdCols),
    Max = maximum.(stdCols),
    SD = std.(stdCols),
) |> println;
28×5 DataFrame
 Row │ Feature               Min        Mean          Max       SD
     │ String                Float64    Float64       Float64   Float64
─────┼──────────────────────────────────────────────────────────────────
   1 │ tcpO2                 -1.42178   -6.80459e-17  1.94898       1.0
   2 │ Rutherford__5.0       -3.03031    8.23714e-17  0.324676      1.0
   3 │ Pain                  -2.00892   -9.31155e-17  2.22758       1.0
   4 │ Sex                   -0.381784  -1.61161e-17  2.57704       1.0
   5 │ Age                   -2.39604    4.59086e-16  2.27203       1.0
   6 │ ArterialHypertension  -1.83665    5.73018e-17  0.535689      1.0
   7 │ DiabetesMellitus      -1.38596    1.14604e-16  0.709883      1.0
   8 │ Hyperlipidemia        -1.02443    5.37205e-17  0.960406      1.0
  ⋮  │          ⋮                ⋮           ⋮           ⋮         ⋮
  22 │ Wound size base       -0.762213  -3.60375e-17  3.62204       1.0
  23 │ sICAM-1               -1.12333   -5.10344e-16  6.19511       1.0
  24 │ sICAM-3               -1.85742   -3.27695e-16  4.26157       1.0
  25 │ sEselectin            -1.07723   -6.44646e-17  6.32657       1.0
  26 │ sPselectin            -1.27739   -1.88022e-16  6.17412       1.0
  27 │ sPECAM-1              -1.45862   -6.91203e-16  3.28237       1.0
  28 │ VEGF                  -0.770142  -2.32789e-17  3.14753       1.0
                                                         13 rows omitted

3 Unsupervised learning

In ML, each data point is usually an \((input, label)\) pair. When a model processes the \(input\) value, it’s expected to output the \(label\) value, or at least something close to it.

This rationale defines supervised ML applications, ones with a label associated with each data point. Common examples of pairs are \((image, category)\), or, for the current problem, \((patient \: data, treatment \: response)\).

On the other hand, the ML subfield of unsupervised learning builds models for data without labels. Its main uses are to gain insights about the data and learn patterns, instead of making predictions.

3.1 Dimensionality reduction

Within unsupervised learning, dimensionality reduction (Jacob Murel 2024) can be used to reduce the number of dimensions of the data. This is useful to (Géron 2019):

  1. Avoid the curse of dimensionality.
  2. Compress the information in the data.
  3. Visualize the dataset.
  4. Speed up training.

A great example of an unsupervised learning technique is Principal Component Analysis (PCA) (Jaadi 2024). It projects the data points onto a lower-dimensional space, searching for orthogonal axes that account for as much of the data’s variance as possible. Figure 2 illustrates the information retained in PCA as more principal components are included.

Figure 2: Information retained in PCA depending on number of principal components kept.
matIn = Matrix(stdInputs) # Inputs as matrix
pca = MultivariateStats.fit(MultivariateStats.PCA, matIn'; maxoutdim = 31)
valsP = eigvals(pca) # Eigenvalues
fig = Figure(; size = (800, 500), fontsize = 18)
ax = Axis(
    fig[1, 1],
    xlabel = "Number of principal components",
    ylabel = "Percentage of variance",
    yticks = 0:10:100,
    xticks = 1:2:31,
    ytickformat = "{:d}%",
)
cumulativePercentage = cumsum(valsP ./ sum(valsP)) .* 100
lines!(ax, cumulativePercentage)
save("./eigVar.svg", fig)

In the current case, the patient vectors from the previous step can be mathematically interpreted as points in a 28-dimensional space. In the following, PCA is used to project the points to a 3-dimensional space, producing the scatter plot in Figure 3.

mach = machine(PCA(; maxoutdim = 3), stdInputs)
MLJ.fit!(mach, verbosity = 0) # Compute principal components
# Project data to lower dimension
lowInputsViz = MLJ.transform(mach, stdInputs);
schema(lowInputsViz) # check dimension
┌───────┬────────────┬─────────┐
│ names │ scitypes   │ types   │
├───────┼────────────┼─────────┤
│ x1    │ Continuous │ Float64 │
│ x2    │ Continuous │ Float64 │
│ x3    │ Continuous │ Float64 │
└───────┴────────────┴─────────┘
Figure 3: Scatter plot of the 3-dimensional projection of the data. Colors mapped to the first principal component (x axis).
# low-dimensional visualization
using CairoMakie
machPCAviz = machine(PCA(; maxoutdim = 3), stdInputs)
MLJ.fit!(machPCAviz) # Compute principal components
# Project data to lower dimension
lowInputsViz = MLJ.transform(machPCAviz, stdInputs)
pcaFig = Figure(; fontsize = 18)
ax = Axis3(pcaFig[1, 1])
scatter!(
  eachcol(lowInputsViz)...;
  colormap = :thermal,
  color = lowInputsViz[:, 1],
)
save("./PCA.svg", pcaFig)

A limitation of standard PCA is its linearity: it can only capture linear structure in the data. Kernel PCA (Thompson 2018) addresses this by first using a ‘kernel function’ to implicitly map the data to a higher-dimensional space, where the structure of the points often becomes approximately linear.

As a side example, the data could be projected from 28 to 10 dimensions using a squared-exponential kernel. The mean squared reconstruction error then comes out to 0.84, which is high for standardized (unit-variance) data, indicating that this particular kernel projection does not represent the data well.

mach = machine(
    KernelPCA(; maxoutdim = 10, kernel = (x, y) -> SqExponentialKernel()(x, y)),
    stdInputs,
)
MLJ.fit!(mach, verbosity = 0)
KPCA_low_inputs = MLJ.transform(mach, stdInputs)
# Reconstruction
reconst = inverse_transform(mach, KPCA_low_inputs) |> DataFrame |> Matrix
# Mean squared reconstruction error
reconstError = (Matrix(stdInputs) .- reconst) .^ 2 |> mean
0.8394471407044392

Standard PCA, in contrast, leads to a mean squared reconstruction error of 0.29 for the current case. As a result, the PCA dimensionality reduction will be used for the remainder of the tutorial.

mach = machine(PCA(; maxoutdim = 10), stdInputs)
MLJ.fit!(mach, verbosity = 0)
lowInputs = MLJ.transform(mach, stdInputs)
# Reconstruction
reconstPCA = inverse_transform(mach, lowInputs) |> DataFrame |> Matrix
# Mean reconstruction error
PCAreconstError = (Matrix(stdInputs) .- reconstPCA) .^ 2 |> mean
0.2862226417791343

3.2 K-medoids

It’s important to remember that supervised and unsupervised techniques are not antagonistic: each has its uses, pros, and cons, and they can be combined. For the rest of this tutorial, the previous 10-dimensional PCA projection of the data is used as the input to the models.

Now the unsupervised learning technique K-medoids will be used (H 2024). Its goal is clustering: splitting the data according to groupings of the points. The following steps are required:

  1. The number of clusters, K, is specified.
  2. K data points are chosen as the centers (medoids) of the clusters.
  3. The remaining points are assigned to the cluster of the closest medoid.
  4. The K medoids are adjusted to minimize the sum of distances from each medoid to all points assigned to its cluster.

K-means is another common clustering algorithm. The main difference from K-medoids is that the cluster centers don’t need to be data points; instead, each center is the average of all points in its cluster.

mach = machine(KMedoids(; k = 2), lowInputs)
MLJ.fit!(mach, verbosity = 0)
# Predict point assignments
KMpredictions = MLJ.predict(mach, lowInputs)
first(KMpredictions, 5)
5-element CategoricalArrays.CategoricalArray{Int64,1,UInt32}:
 1
 1
 1
 2
 1
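The numbered steps above can be sketched in plain Python. This is an illustrative toy, not the implementation behind MLJ's KMedoids; the initial medoids are fixed by hand for reproducibility:

```python
def k_medoids(points, k, dist, init, iters=50):
    """Alternate (a) assigning each point to its nearest medoid and
    (b) moving each medoid to the cluster member that minimizes the
    sum of distances to the rest of its cluster."""
    medoids = list(init)  # indices of the initial medoids
    clusters = []
    for _ in range(iters):
        # (a) assignment step
        clusters = [[] for _ in range(k)]
        for i, p in enumerate(points):
            nearest = min(range(k), key=lambda j: dist(p, points[medoids[j]]))
            clusters[nearest].append(i)
        # (b) update step (toy: assumes no cluster ends up empty)
        new_medoids = [
            min(c, key=lambda m: sum(dist(points[m], points[i]) for i in c))
            for c in clusters
        ]
        if new_medoids == medoids:  # converged
            break
        medoids = new_medoids
    return medoids, clusters

euclid = lambda a, b: sum((u - v) ** 2 for u, v in zip(a, b)) ** 0.5
pts = [(0.0, 0.0), (0.1, 0.2), (0.2, 0.1), (5.0, 5.0), (5.1, 4.9)]
medoids, clusters = k_medoids(pts, 2, euclid, init=[0, 3])
print(clusters)  # the three points near the origin vs. the two near (5, 5)
```

Replacing the update step with a cluster-mean computation (and letting centers be arbitrary points) would turn this into K-means.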

3.3 Tangent: hyperparameters

In K-medoids, K is a hyperparameter: a parameter that the model does not learn during training. Rather, this type of parameter is manually set by the user when setting up the model, before training.

From the study’s context of predicting whether patients will respond to the treatment, here the number of clusters K is set to 2. Often, however, the parameter K in clustering techniques is not known beforehand.

In fact, a common use of clustering is to determine K itself. This amounts to revealing whether the data points are organized in groups, which tend to be unrecognizable by manual data analysis, especially in high-dimensional spaces.

To determine hyperparameters, optimization techniques can be applied. For this, the dataset is usually split between training and validation sets. In such a context, the optimization tends to follow the general procedure in Figure 4. The Tuning Models page in MLJ.jl’s documentation includes examples of strategies.

Figure 4: Usual logic for hyperparameter optimization.
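In plain Python, the loop in Figure 4 reduces to something like the sketch below. The model is a deliberately tiny toy (a constant predictor shrunk toward zero by a hypothetical hyperparameter lam) so that the set/train/validate/compare cycle stands out:

```python
train = [2.1, 1.9, 2.3, 2.0]  # training split
val = [2.2, 1.8]              # validation split

def fit(data, lam):
    """Toy model: predict the training mean shrunk toward zero by lam."""
    mean = sum(data) / len(data)
    return lambda: (1.0 - lam) * mean

def val_error(predict, data):
    """Mean squared error on the validation split."""
    return sum((v - predict()) ** 2 for v in data) / len(data)

# For each candidate hyperparameter: train, validate, and keep the best.
candidates = [0.0, 0.1, 0.3, 0.5]
best = min(candidates, key=lambda lam: val_error(fit(train, lam), val))
print(best)  # 0.0: no shrinkage gives the lowest validation error here
```

Real tuning strategies differ only in how the candidate list is generated and traversed; the train-on-one-split, score-on-the-other skeleton is the same.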

4 Supervised learning

Decision trees are among the most popular supervised learning algorithms. They work by repeatedly splitting the dataset, one feature at a time, by establishing thresholds on the features.

For example, a node may split the patients between those who are younger than 60 years old and those who are not. And, among the younger ones, another node may split further between those with a cholesterol level lower than 4 mmol/L and those without, and so on. Figure 5 illustrates such a decision tree.

Figure 5: Example of decision tree for current case. Does not reflect the models developed.
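Written out, the toy tree amounts to nested threshold checks. A plain-Python sketch, with the thresholds from the example (the labels are illustrative only, not a fitted model):

```python
def classify(age, cholesterol):
    """Toy decision tree with the thresholds from the example."""
    if age >= 60:             # 1st-level split
        return "non-responder"
    if cholesterol >= 4.0:    # 2nd-level split
        return "non-responder"
    return "responder"

print(classify(50, 3.5))  # responder
```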

Simplifying the problem to only these two features, Figure 6 contains a scatter plot showing the splits of the data and the resulting classifications of patients. The splitting continues until the number of samples reaching a node is too small, or the samples are concentrated in a specific value or interval of the corresponding feature.

Figure 6: Splits and classification from illustrative decision tree. NR - non-responder. Marker colors match respective node in Figure 5.
GLMakie.activate!() # Plotting backend
# Indices of patients in each category
lowYoung = inputs.Age .< 60 .&& inputs.Cholesterol .< 4
highYoung = inputs.Age .< 60 .&& inputs.Cholesterol .>= 4
older = inputs.Age .>= 60
fig = Figure(; fontsize = 18) # Figure to draw in
ax = Axis(fig[1, 1])
vlines!(ax, 60; color = :brown) # Age split
age = inputs[:, :Age]
# Function for relative position in x axis
ageRel(x) = (x - minimum(age)) / (maximum(age) - minimum(age))
# Cholesterol split
hlines!(ax, 4; xmin = ageRel(38), xmax = ageRel(60.4), color = :blue)
# Scatter plots
plot1 = scatter!( # Non-responders on 1st level
    ax,
    inputs[older, "Age"],
    inputs[older, "Cholesterol"];
    color = :green,
    marker = :diamond,
    markersize = 13,
)
plot2 = scatter!( # Non-responders on 2nd level
    ax,
    inputs[highYoung, "Age"],
    inputs[highYoung, "Cholesterol"];
    color = :purple,
    marker = :xcross,
    markersize = 13,
)
plot3 = scatter!( # Responders
    ax,
    inputs[lowYoung, "Age"],
    inputs[lowYoung, "Cholesterol"];
    color = :red,
    marker = :pentagon,
    markersize = 13,
)
Legend( # Labels for splits
    fig[1, 2],
    [plot1, plot2, plot3],
    ["NR 1st level", "NR 2nd level", "Responders"],
)
# Save SVG
save("./decisionTreeScatter.svg", fig)

Decision trees are very popular in ML, since they are interpretable, versatile, and scalable. However, they can be too sensitive to small or superficial changes in the dataset. To overcome this, they can be used in ensembles, which combine multiple simpler models into a single, more robust learner (Brownlee 2024).

For decision trees, a common ensemble is the random forest. It contains multiple trees, each trained on a random subset of the data. Each tree also uses only a random subset of the features, which avoids building many similar trees when a few features have high predictive power. Usually, random forests use a voting system among the trees for classification, and average the trees’ outputs for regression.

4.1 Gradient boosted trees

Another type of ensemble is called boosting (Winston 2014): training the weak learners sequentially, each one focusing on cases that posed a challenge to the previous model.

In gradient boosting (Starmer 2019) specifically, the first model \(M_1\) is trained to map features \(X\) to labels \(Y\), so \(M_1: X \rightarrow Y\) as usual. Then, the second model \(M_2\) is trained to predict, from the features, the residuals of the first model (\(r = Y - M_1(X)\)): \(M_{2}: X \rightarrow r\). A third model does the same with the residuals left by the first two combined, and so on. Lastly, the final prediction made by the ensemble is the sum of all these intermediate outputs.
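This residual-fitting loop can be sketched in plain Python, using hand-rolled depth-1 trees (“stumps”) as the weak learners. It is a toy for squared loss, not EvoTrees’ implementation:

```python
def fit_stump(x, r):
    """Depth-1 regression tree: find the split of the 1-D feature x
    that minimizes the squared error on the targets r."""
    best = None
    for t in sorted(set(x))[:-1]:  # candidate thresholds
        left = [ri for xi, ri in zip(x, r) if xi <= t]
        right = [ri for xi, ri in zip(x, r) if xi > t]
        ml, mr = sum(left) / len(left), sum(right) / len(right)
        err = sum((ri - ml) ** 2 for ri in left) + sum((ri - mr) ** 2 for ri in right)
        if best is None or err < best[0]:
            best = (err, t, ml, mr)
    _, t, ml, mr = best
    return lambda xi: ml if xi <= t else mr

def gradient_boost(x, y, n_stages=10):
    """Each stage fits the residuals left by the running ensemble;
    the ensemble's prediction is the sum of all stage outputs."""
    pred = [0.0] * len(x)
    for _ in range(n_stages):
        r = [yi - pi for yi, pi in zip(y, pred)]  # current residuals
        stump = fit_stump(x, r)
        pred = [pi + stump(xi) for pi, xi in zip(pred, x)]
    return pred

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = [1.0, 1.2, 0.9, 3.0, 3.2, 2.9]
pred = gradient_boost(x, y)
mse = sum((yi - pi) ** 2 for yi, pi in zip(y, pred)) / len(y)
```

With each stage, the training residuals shrink, so the training error is non-increasing; as discussed later, validation error behaves differently once the ensemble starts fitting noise.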

Such a procedure equates to stacking consecutive corrections to the initial prediction. This ensemble is used below as the base in the TunedModel MLJ.jl utility, which provides workflows for hyperparameter optimization.

GBoostModel = EvoTreeClassifier() # Base model
# Target hyperparameters and their ranges
paramRanges = [
    range(GBoostModel, :max_depth, values = 3:6),
    range(GBoostModel, :nrounds, values = 10:5:100),
]
# Model wrapper for optimization
selfTuningGBoost = TunedModel(
    model = GBoostModel,
    resampling = CV(),
    tuning = Grid(shuffle=false),
    range = paramRanges,
    measure = Accuracy(),
    train_best = false,
)
mach = machine(selfTuningGBoost, lowInputs, outputs)
MLJ.fit!(mach, verbosity = 0);

If we set train_best=true (the default), we can use mach to make new predictions, as with any other machine. In that case, the predictions are based on training the best model on all the provided data (all rows of lowInputs and outputs). We can also extract the best model:

bestGBoost = fitted_params(mach).best_model
EvoTreeClassifier(
  loss = :mlogloss, 
  metric = :mlogloss, 
  nrounds = 20, 
  bagging_size = 1, 
  early_stopping_rounds = 9223372036854775807, 
  L2 = 1.0, 
  lambda = 0.0, 
  gamma = 0.0, 
  eta = 0.1, 
  max_depth = 4, 
  min_weight = 1.0, 
  rowsample = 1.0, 
  colsample = 1.0, 
  nbins = 64, 
  alpha = 0.5, 
  tree_type = :binary, 
  rng = MersenneTwister(123, (0, 3006, 2004, 906)), 
  device = :cpu)

And evaluate its performance on the whole dataset:

MLJ.evaluate(bestGBoost, lowInputs, outputs; measure = Accuracy(), verbosity = 0)
PerformanceEvaluation object with these fields:
  model, measure, operation,
  measurement, per_fold, per_observation,
  fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌────────────┬──────────────┬─────────────┐
│ measure    │ operation    │ measurement │
├────────────┼──────────────┼─────────────┤
│ Accuracy() │ predict_mode │ 0.742       │
└────────────┴──────────────┴─────────────┘
┌────────────────────────────────────┬─────────┐
│ per_fold                           │ 1.96*SE │
├────────────────────────────────────┼─────────┤
│ [0.727, 0.727, 0.7, 0.6, 0.9, 0.8] │ 0.0883  │
└────────────────────────────────────┴─────────┘

In this last code block, evaluate returns a PerformanceEvaluation object, which produces the output shown. The per_fold field contains the value of the measure on the validation set of each data split. The measure used was accuracy: the rate of correct predictions. The “SE” in the 1.96*SE column refers to the standard error of these accuracy estimates. Under a number of assumptions (which generally hold only approximately), the number in this column is half the width of a 95% confidence interval for the performance estimate, centered on the aggregated value reported in the measurement field. This can be helpful when comparing models, but should be used with caution. The aggregation weights the folds by the sizes of their validation splits. Lastly, it’s worth noting the variation that data splitting causes in the validation accuracy. Ideally, the model’s performance shouldn’t be too sensitive to data splitting, resulting in consistent results across folds.
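The printed numbers can be reproduced from the per-fold values above. Matching the displayed 0.0883 suggests the standard error here is the sample standard deviation divided by the square root of the number of folds minus one; this is inferred from the output, not a documented guarantee:

```python
from math import sqrt

per_fold = [0.727, 0.727, 0.7, 0.6, 0.9, 0.8]  # per-fold accuracies above
n = len(per_fold)
mean = sum(per_fold) / n
# sample standard deviation of the fold accuracies
sd = sqrt(sum((a - mean) ** 2 for a in per_fold) / (n - 1))
half_width = 1.96 * sd / sqrt(n - 1)
print(round(mean, 3), round(half_width, 4))  # 0.742 0.0883
```

Here the unweighted mean of the fold accuracies agrees with the reported measurement to the displayed precision.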

Going back to the TunedModel definition, another setting is the resampling strategy in resampling = CV(). As mentioned in Section 3.3, the reason for splitting the data between training and validation is to avoid overfitting (Google 2024): adjusting the model too tightly to the training data distribution.

It’s important to avoid this phenomenon because data includes a (hopefully not dominant) random component, which is not informative of the underlying process the data represents. And, if the model gets to the point of “memorizing” the training data, it won’t generalize well, resulting in a considerable performance drop when facing new cases.

Hence it’s standard practice in ML to split the data between training and validation (and sometimes also for testing), so one tracks not only the performance improvement during training, but also the validation behavior. There should be an initial phase during which both training and validation metrics improve. Afterwards however, validation performance starts dropping, indicating overfitting. Figure 7 shows the performance trends in the training and validation splits for different gradient boost ensembles.

Figure 7: Performance trends in training and validation splits when overfitting.
# Split data
trainIndices, valIndices = partition(1:nrow(lowInputs), 0.6)
# Instantiate base model
model = EvoTreeClassifier()
train_test_sets = [(trainIndices, trainIndices), (trainIndices, valIndices)]
curves = map(train_test_sets) do set
    learning_curve(
        model,
        lowInputs,
        outputs;
        resolution=200,
        resampling=[set,],
        measure=Accuracy(),
        range=range(model, :nrounds, lower=1, upper=400),
    )
end
# Plotting
fig = Figure(; fontsize = 18)
ax = Axis(fig[1, 1], xlabel = "Number of trees", ylabel = "Accuracy")
x = curves[1].parameter_values;
y_train = curves[1].measurements;
y_val = curves[2].measurements;
lTrain = lines!(x, y_train)
lVal = lines!(x, y_val)
Legend( # Labels for splits
    fig[1, 2],
    [lTrain, lVal],
    ["Training", "Validation"],
)
save("./overfit.svg", fig)

The y axis indicates the accuracy: the proportion of correct predictions (the higher, the better). The x axis indicates the number of trees included in the ensemble. Increasing this number expands the model’s capacity (Ian Goodfellow 2016): the diversity of mappings it can represent. If capacity is too low, the model underfits the training data and doesn’t perform as well as it could. If it’s too high, the model has enough flexibility to fit even the noise in the training set. Here, overfitting is identified by the validation performance reaching a peak and then diminishing.

Further, aside from manually splitting the dataset (“holdout validation”), a common strategy in the context of hyperparameter search is cross-validation (CV), which Figure 8 illustrates:

  1. Split the data into \(k\) ‘folds’ of roughly equal size.
  2. Iterate on the \(k\) folds:
    1. Assign the fold for validation.
    2. Set the model hyperparameters.
    3. Create a new instance of the model.
    4. Train the current model instance on the training folds.
    5. Validate the current model instance on the validation fold.
    6. Log the performance.
  3. Use a strategy to choose the hyperparameters according to the validation errors.
Figure 8: Usage of data splits across iterations of cross-validation.
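As a plain-Python sketch, the fold construction in step 1 (splitting n data points into k folds, and pairing each fold with its complementary training set) can be written as:

```python
def kfold_indices(n, k):
    """Split indices 0..n-1 into k contiguous folds of (almost) equal
    size, yielding one (train, validation) pair per fold."""
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for s in sizes:
        folds.append(list(range(start, start + s)))
        start += s
    for j in range(k):
        val = folds[j]
        train = [i for jj, f in enumerate(folds) if jj != j for i in f]
        yield train, val

# With 62 patients and 6 folds, every point is validated on exactly once.
splits = list(kfold_indices(62, 6))
print([len(val) for _, val in splits])  # [11, 11, 10, 10, 10, 10]
```

These fold sizes (two of 11 and four of 10) are consistent with the per-fold accuracies reported earlier by evaluate.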

K-fold cross-validation is sometimes also used without changing hyperparameters between iterations. That’s useful to get a broader perspective on the model’s performance, averaging out the impact of which data points are used for training. This is what MLJ’s evaluate() function does.

Furthermore, regarding the number of folds K (Kohavi 2001), values from 5 to 10 are common. Higher values increase the computational cost of the procedure, but reduce the bias when estimating the model’s performance. Lower values leave less data for training in each iteration, which may be adequate for small models, but can lead large models to overfit.

Going back to the gradient-boosted TunedModel: after resampling = CV(), there is also the tuning = Grid() setting, specifying the strategy used to update the hyperparameters across iterations. Grid is the simplest option in MLJ.jl, simply testing a certain number of equally spaced values. This strategy quickly becomes underwhelming in demanding applications, since it amounts to a trial-and-error exploration of a few possibilities rather than an efficient search algorithm (Run:ai 2024).

Those values are specified by the next setting, range = paramRanges, which lists the parameters to be optimized and the range of values to be considered for each. Here, the max_depth parameter was optimized over the range 3:6, controlling the maximum depth allowed for the decision trees that underlie the gradient-boosting ensemble. Additionally, the parameter nrounds was optimized over the range 10:5:100, determining the number of trees.
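For intuition, the candidate combinations that such a grid search visits can be enumerated directly. A minimal sketch using the two ranges above (the actual search and refitting is performed by TunedModel):

```julia
# Enumerate the hyperparameter grid for the gradient-boosted model.
depths   = 3:6         # candidate values for max_depth
ntrees   = 10:5:100    # candidate values for nrounds (number of trees)
grid = [(d, n) for d in depths, n in ntrees]
length(grid)           # 4 depths × 19 tree counts = 76 combinations
```

Every one of these 76 combinations is evaluated by cross-validation, which illustrates why grid search scales poorly as parameters are added.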

After fitting the TunedModel, a model with the optimal hyperparameters can be extracted and used. Below, such a model is fitted to the training split of manually partitioned data. Later on, it will be used to make predictions on the validation set.

# Manual 60:40 data split
trainIndices, valIndices = partition(1:nrow(lowInputs), 0.6)
# Bind model to training split, then fit
mach = machine(bestGBoost, lowInputs[trainIndices, :], outputs[trainIndices])
MLJ.fit!(mach, verbosity = 0);

4.2 Model evaluation

In ML, there are usually numerous options to perform any step in the project, from data preparation to performance evaluation. As an example, there are various metrics for the current case of binary classification.

The confusion matrix enables a handful of useful metrics. After making predictions with the model, it is built by arranging in a matrix how many cases fall in each of the following categories:

  1. True positive (TP): a positive case (responder) that was correctly classified as positive.
  2. True negative (TN): a negative case (non-responder) that was correctly classified as negative.
  3. False positive (FP, type I error): a negative case that was misclassified as positive.
  4. False negative (FN, type II error): a positive case that was misclassified as negative.
predictionsGBoost = predict_mode(mach, lowInputs[valIndices, :])
# Confusion matrix [TN FN; FP TP]
cmGBoost = confmat(predictionsGBoost, outputs[valIndices])
          ┌─────────────┐
          │Ground Truth │
┌─────────┼──────┬──────┤
│Predicted│ 0.0  │ 1.0  │
├─────────┼──────┼──────┤
│   0.0   │  7   │  2   │
├─────────┼──────┼──────┤
│   1.0   │  5   │  11  │
└─────────┴──────┴──────┘
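These counts can be computed by hand. A minimal sketch with short hypothetical prediction and label vectors (not the actual study data):

```julia
# Count confusion-matrix entries for binary labels (1 = responder).
predictions = [1, 1, 0, 1, 0, 0, 1]
truth       = [1, 0, 0, 1, 1, 0, 1]
tp = count(predictions .== 1 .&& truth .== 1)   # true positives
tn = count(predictions .== 0 .&& truth .== 0)   # true negatives
fp = count(predictions .== 1 .&& truth .== 0)   # false positives
fn = count(predictions .== 0 .&& truth .== 1)   # false negatives
(tp, tn, fp, fn)   # (3, 2, 1, 1)
```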

From these basic definitions, the following binary classification metrics can be calculated to prioritize different aspects of performance:

  1. Precision, or positive predictive value (PPV): the proportion of positive classifications that are correct.

\[\frac{TP}{TP + FP}\]

  2. Recall, sensitivity or true positive rate (TPR): the proportion of positive cases that are correctly classified.

\[\frac{TP}{TP + FN}\]

  3. F1 score: the harmonic mean of precision and recall. It is only high if both previous metrics are high.

\[2 \cdot \frac{Precision \cdot Recall}{Precision + Recall}\]

  4. Accuracy: the fraction of all classifications that are correct.

\[\frac{TP + TN}{TP + TN + FP + FN}\]
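Substituting the counts from the confusion matrix above (TP = 11, TN = 7, FP = 5, FN = 2) into these formulas reproduces the values reported by evaluate below:

```julia
# Metrics from the confusion matrix: TP = 11, TN = 7, FP = 5, FN = 2.
tp, tn, fp, fn = 11, 7, 5, 2
prec     = tp / (tp + fp)                        # ≈ 0.688 (precision / PPV)
recall   = tp / (tp + fn)                        # ≈ 0.846
f1       = 2 * prec * recall / (prec + recall)   # ≈ 0.759
accuracy = (tp + tn) / (tp + tn + fp + fn)       # 0.72
```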

We can extract these metrics directly from the confusion matrix, or obtain them using evaluate as shown below. In the results, it’s worth pointing out the drop in accuracy in comparison to the measurement field in the previous call to MLJ.evaluate(bestGBoost, lowInputs, outputs...). This highlights one of the consequences of small datasets: the impact of data splitting. The impact can also be noticed by the variation in performance across folds in the previous MLJ.evaluate call (Section 4.1).

# Metrics
ppv(cmGBoost)      # `ppv` is an alias for `PositivePredictiveRate()`
recall(cmGBoost)   # `recall` is an alias for `TruePositiveRate()`
f1score(cmGBoost)  # `f1score` is an alias for `FScore()`
accuracy(cmGBoost)
# The same metrics using `evaluate`:
MLJ.evaluate(
    bestGBoost,
    lowInputs,
    outputs;
    resampling=[(trainIndices, valIndices),],
    measures = [ppv, recall, f1score, accuracy, cross_entropy],
)
PerformanceEvaluation object with these fields:
  model, measure, operation,
  measurement, per_fold, per_observation,
  fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌───┬──────────────────────────┬──────────────┬─────────────┐
│   │ measure                  │ operation    │ measurement │
├───┼──────────────────────────┼──────────────┼─────────────┤
│ A │ PositivePredictiveValue( │ predict_mode │ 0.688       │
│   │   levels = nothing,      │              │             │
│   │   rev = nothing,         │              │             │
│   │   checks = true)         │              │             │
│ B │ TruePositiveRate(        │ predict_mode │ 0.846       │
│   │   levels = nothing,      │              │             │
│   │   rev = nothing,         │              │             │
│   │   checks = true)         │              │             │
│ C │ FScore(                  │ predict_mode │ 0.759       │
│   │   beta = 1.0,            │              │             │
│   │   levels = nothing,      │              │             │
│   │   rev = nothing,         │              │             │
│   │   checks = true)         │              │             │
│ D │ Accuracy()               │ predict_mode │ 0.72        │
│ E │ LogLoss(                 │ predict      │ 0.559       │
│   │   tol = 2.22045e-16)     │              │             │
└───┴──────────────────────────┴──────────────┴─────────────┘

In the evaluate call in the last code block, the metric cross_entropy was provided as well. It refers to binary cross-entropy (Godoy 2018), a measure of dissimilarity between two distributions. After fitting, the outputs of the model should follow the same distribution as the original data. Therefore, binary cross-entropy can be used as an objective to be minimized when training binary classification models, leading to the same results as maximizing the log-likelihood (Draelos 2019). In the following formula, \(N\) is the number of cases being compared; \(y_i \in \{0, 1\}\) is the true label of case \(i\); and \(\hat{p}_i\) is the predicted probability that case \(i\) is positive.

\[CE = -\frac{1}{N} \sum^{N}_{i = 1} \left[ y_i \cdot \log(\hat{p}_i) + (1 - y_i) \cdot \log(1 - \hat{p}_i) \right]\]
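As a sketch, the formula translates directly into Julia (bce is a hypothetical helper; p holds the predicted probabilities of the positive class):

```julia
# Binary cross-entropy for labels y ∈ {0, 1} and predicted probabilities p.
bce(y, p) = -sum(@. y * log(p) + (1 - y) * log(1 - p)) / length(y)

y = [1, 0, 1, 1]
p = [0.9, 0.2, 0.8, 0.6]
bce(y, p)   # small when the predicted probabilities agree with the labels
```

Confident, correct predictions drive the value toward 0, while confident mistakes are penalized heavily by the logarithm.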

For documentation on performance measures (metrics) available in MLJ, refer to the StatisticalMeasures.jl docs.

4.3 K-nearest neighbors pipeline

The next model to be used is K-nearest neighbors (elastic 2024). It is a supervised approach that compares the input against the K nearest training points. In regression problems, it averages the labels of these neighbors, giving greater weight to those that are closest; in classification, it outputs class probabilities based on how many of the neighbors, weighted by proximity, belong to each class. Here, the parameter K defines how many neighboring points will be considered, and its default is 5.
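To illustrate the idea before building the MLJ pipeline, here is a toy, unweighted version in plain Julia (knn_predict is a hypothetical helper; MLJ’s KNNClassifier adds distance weighting and faster tree-based neighbor search):

```julia
# Toy K-nearest-neighbor vote for 2-D points with binary labels.
function knn_predict(X, labels, x, K)
    dists = [sqrt(sum((row .- x) .^ 2)) for row in eachrow(X)]  # Euclidean
    nearest = sortperm(dists)[1:K]         # indices of the K closest points
    return sum(labels[nearest]) / K        # probability of class 1
end

X = [0.0 0.0; 0.1 0.2; 1.0 1.0; 0.9 1.1]   # four training points
labels = [0, 0, 1, 1]
knn_predict(X, labels, [0.95, 1.0], 3)     # ≈ 0.67: 2 of 3 neighbors are class 1
```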

In this case we will use an MLJ pipeline, a functionality that enables us to concisely combine multiple elements of our workflow into a single model. Starting with missing-value imputation, the first step of our workflow that involved genuine “learning”, we combine the previous steps into a pipeline, with a K-nearest neighbors model added at the end.

# Pipeline model
modelPipe =
    FillImputer |>
    Standardizer |>
    PCA(; maxoutdim = 10) |>
    knn(; K = 3)
ProbabilisticPipeline(
  fill_imputer = FillImputer(
        features = Symbol[], 
        continuous_fill = MLJModels._median, 
        count_fill = MLJModels._round_median, 
        finite_fill = MLJModels._mode), 
  standardizer = Standardizer(
        features = Symbol[], 
        ignore = false, 
        ordered_factor = false, 
        count = false), 
  pca = PCA(
        maxoutdim = 10, 
        method = :auto, 
        variance_ratio = 0.99, 
        mean = nothing), 
  knn_classifier = KNNClassifier(
        K = 3, 
        algorithm = :kdtree, 
        metric = Euclidean(0.0), 
        leafsize = 10, 
        reorder = true, 
        weights = NearestNeighborModels.Uniform()), 
  cache = true)
Note

The A |> B operator in Julia is called piping: it uses the output of expression A as the input to expression B, clearly indicating the flow of information.
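For example, a nested call and its piped equivalent:

```julia
# The same computation, nested vs. read left to right with pipes.
sqrt(sum([1, 4, 4]))       # 3.0
[1, 4, 4] |> sum |> sqrt   # 3.0
```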

Let’s train this model using 70% of the observations:

# Split the data into train and validation:
train, validation = partition(1:length(outputs), 0.7)
mach = machine(modelPipe, inputs[train, :], outputs[train])
fit!(mach, verbosity = 0);

The model’s “machine”, mach, can be saved to disk with MLJ.save("pipeModel.jld2", mach), and loaded later: for example, when starting a new session, continuing the project, or building a model catalogue. Combined with pipelines, this is convenient when a new file with additional data arrives: the machine can be loaded with machLoaded = machine("./pipeModel.jld2"), and the entire workflow discussed so far reapplied with pipePredictions = MLJ.predict(machLoaded, df[:, colIDs]), directly on raw data that has missing values and isn’t preprocessed.

Furthermore, with this model, the output for each case is a probability distribution over the two possibilities; here, this is the chance of the patient responding to the treatment. Therefore, to assess the model’s performance, the area under the receiver operating characteristic curve (AUC) can be applied.

In short, when the model predicts a certain chance that a patient will respond to the treatment, a threshold (60%, 70% etc.) needs to be established, above which the patient is classified as a responder. However, the threshold level impacts the metrics in the confusion matrix (Section 4.2), so the best value for a given context must be chosen.

The receiver operating characteristic (ROC) curve is commonly used for this. It plots the true positive rate (TPR, i.e. recall or sensitivity) on the y axis, and the false positive rate (FPR) on the x axis. For each threshold level, the model’s predictions result in an (FPR, TPR) pair, representing a point in the plot; sweeping the threshold generates a curve.

A model that randomly guesses the outcome (equivalent to a coin flip) is graphically represented by a 45° line. Ideally, the curve that results from the model isn’t only above such a straight line, but also approaches the top left corner, where TPR is 1 and FPR is 0. This can be measured through the AUC, which would then approach 1.
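To make the construction concrete, here is a sketch that sweeps thresholds over hypothetical predicted probabilities, collects the (FPR, TPR) points, and accumulates the AUC with the trapezoidal rule (MLJ’s roc_curve and auc do this for us):

```julia
# ROC points and AUC from scratch, for hypothetical scores (not the study data).
scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.1]   # predicted P(responder) per patient
truth  = [1,   1,   0,   1,   0,   0]     # actual outcomes
points = map(sort(unique([0.0; scores; 1.0]); rev = true)) do t
    pos = scores .>= t                     # classified positive at threshold t
    tpr = count(pos .&& truth .== 1) / count(truth .== 1)
    fpr = count(pos .&& truth .== 0) / count(truth .== 0)
    (fpr, tpr)
end
# Trapezoidal area under the (FPR, TPR) curve:
fprs, tprs = first.(points), last.(points)
auc_est = sum((fprs[i+1] - fprs[i]) * (tprs[i+1] + tprs[i]) / 2
              for i in 1:length(points)-1)   # ≈ 0.89 for these scores
```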

Finally, for the previous pipeline model, the ROC on the validation split is included in Figure 9. The x and y coordinates of the points on this curve are given by false_positive_rates and true_positive_rates in the computation below.

pipePredictions = MLJ.predict(mach, inputs[validation, :])
auc(pipePredictions, outputs[validation])
0.8571428571428571
false_positive_rates, true_positive_rates, thresholds =
    roc_curve(pipePredictions, outputs[validation])
thresholds
4-element Vector{Float64}:
 1.0
 0.6666666666666666
 0.3333333333333333
 0.0
Figure 9: ROC for kNN pipeline model.
fig = Figure(; fontsize = 18) # Create figure to draw in
axROC = Axis(
    fig[1, 1];
    title = "Receiver operating characteristic (ROC) curve",
    xlabel = "False positive rate",
    ylabel = "True positive rate",
    xticks = 0:0.2:1,
    yticks = 0:0.2:1,
)
l45 = lines!(axROC, [0, 1], [0, 1]; linestyle = :dash) # 45° line
lModel = lines!(axROC, false_positive_rates, true_positive_rates) # Model
Legend(fig[1, 2], [l45, lModel], ["Random guess", "kNN"])
save("./roc.svg", fig) # Save plot

5 References

Airbyte. 2024. “What Is Data Imputation: Purpose, Techniques, & Methods.” 2024. https://airbyte.com/data-engineering-resources/data-imputation.
Bolstad, B. M., R. A Irizarry, M. Åstrand, and T. P. Speed. 2003. “A Comparison of Normalization Methods for High Density Oligonucleotide Array Data Based on Variance and Bias.” Bioinformatics 19 (2): 185–93. https://doi.org/10.1093/bioinformatics/19.2.185.
Brownlee, Jason. 2023. “A Tour of Machine Learning Algorithms.” 2023. https://machinelearningmastery.com/a-tour-of-machine-learning-algorithms/.
———. 2024. “A Gentle Introduction to Ensemble Diversity for Machine Learning.” 2024. https://machinelearningmastery.com/ensemble-diversity-for-machine-learning/.
Draelos, Rachel. 2019. “Connections: Log Likelihood, Cross Entropy, KL Divergence, Logistic Regression, and Neural Networks.” 2019. https://glassboxmedicine.com/2019/12/07/connections-log-likelihood-cross-entropy-kl-divergence-logistic-regression-and-neural-networks/.
elastic. 2024. “What Is kNN?” 2024. https://www.elastic.co/what-is/knn.
Géron, Aurélien. 2019. Hands-on Machine Learning with Scikit-Learn, Keras, and TensorFlow. O’Reilly.
Godoy, Daniel. 2018. “Understanding Binary Cross-Entropy / Log Loss: A Visual Explanation.” 2018. https://towardsdatascience.com/understanding-binary-cross-entropy-log-loss-a-visual-explanation-a3ac6025181a.
Google. 2024. “Overfitting.” 2024. https://developers.google.com/machine-learning/crash-course/overfitting/overfitting.
H, Prasan N. 2024. “Exploring the World of Clustering: K-Means Vs. K-Medoids.” 2024. https://medium.com/@prasanNH/exploring-the-world-of-clustering-k-means-vs-k-medoids-f648ea738508.
Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. 2016. Deep Learning. MIT Press. https://www.deeplearningbook.org/.
Jaadi, Zakaria. 2024. “Principal Component Analysis (PCA): A Step-by-Step Explanation.” 2024. https://builtin.com/data-science/step-step-explanation-principal-component-analysis.
Murel, Jacob, and Eda Kavlakoglu. 2024. “What Is Dimensionality Reduction?” 2024. https://www.ibm.com/topics/dimensionality-reduction.
Kohavi, Ron. 2001. “A Study of Cross-Validation and Bootstrap for Accuracy Estimation and Model Selection” 14 (March).
Madaric, Juraj, Andrej Klepanec, Martina Valachovicova, Martin Mistrik, Maria Bucova, Ingrid Olejarova, Roman Necpal, Terezia Madaricova, Ludovit Paulis, and Ivan Vulev. 2016. “Characteristics of Responders to Autologous Bone Marrow Cell Therapy for No-Option Critical Limb Ischemia.” Stem Cell Research & Therapy 7 (1): 116. https://doi.org/10.1186/s13287-016-0379-z.
Mehreen, Kanwal. 2022. “Top 5 Machine Learning Practices Recommended by Experts.” 2022. https://www.kdnuggets.com/2022/09/top-5-machine-learning-practices-recommended-experts.html.
Run:ai. 2024. “Hyperparameter Tuning.” 2024. https://www.run.ai/guides/hyperparameter-tuning.
Starmer, Josh. 2019. “XGBoost Part 1 (of 4): Regression.” 2019. https://www.youtube.com/watch?v=OtD8wVaFm6E.
Thompson, David. 2018. “8.6 David Thompson (Part 6): Nonlinear Dimensionality Reduction: KPCA.” 2018. https://www.youtube.com/watch?v=HbDHohXPLnU.
Winston, Patrick. 2014. “17. Learning: Boosting.” 2014. https://www.youtube.com/watch?v=UHBmv7qCey4.
